Unverified Commit ba291dad authored by Daniël de Kok, committed by GitHub

Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditionals in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks
the `quantize` argument. Also, the context manager would not work
with EETQ, FP8, or bitsandbytes.
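
For illustration, here is a minimal, self-contained sketch of such a loader context manager. It assumes only a holder object with a swappable `weights_loader` attribute; the `Weights` class below is a toy stand-in, not TGI's actual implementation.

from contextlib import contextmanager


class Weights:
    """Toy holder for checkpoint tensors plus the loader used to materialize them."""

    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        # Temporarily swap the active loader, e.g. to load an unquantized
        # vision tower while the text model uses a quantized loader.
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = old_loader


# Example: a quantized loader for the text model, another loader elsewhere.
quantized_loader = object()  # stand-in for e.g. an AWQ/GPTQ loader
default_loader = object()    # stand-in for an unquantized loader

weights = Weights(quantized_loader)
with weights.use_loader(default_loader):
    assert weights.weights_loader is default_loader
assert weights.weights_loader is quantized_loader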

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers;
  `get_linear` does not need to know how to handle quantized linear
  layers.
- All quantizer weights are strongly typed; we don't pass around
  raw tensors.
- We don't have to pass around the `quantize` string everywhere.
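
As a rough sketch of the shape this converges on (toy classes for illustration only; the real GPTQ/EETQ/FP8/Marlin weight classes carry their own packed tensors and build kernel-specific linear layers):

from dataclasses import dataclass

import torch


@dataclass
class ToyUnquantizedWeight:
    weight: torch.Tensor

    def get_linear(self, bias):
        # Build a plain nn.Linear from a raw tensor.
        linear = torch.nn.Linear(
            self.weight.shape[1], self.weight.shape[0], bias is not None
        )
        linear.weight = torch.nn.Parameter(self.weight)
        if bias is not None:
            linear.bias = torch.nn.Parameter(bias)
        return linear


@dataclass
class ToyQuantizedWeight:
    # A real quantizer weight would hold qweight/scales/zeros and return a
    # kernel-specific linear layer here instead of delegating.
    weight: torch.Tensor

    def get_linear(self, bias):
        return ToyUnquantizedWeight(self.weight).get_linear(bias)


def get_linear(weight, bias=None):
    # No quantize-string checks: every weight type builds its own linear layer.
    return weight.get_linear(bias)


layer = get_linear(ToyQuantizedWeight(torch.randn(16, 32)))

The actual change wires each quantizer's weight class (GPTQWeight, EETQWeight, Fp8Weight, Exl2Weight, the Marlin weights, ...) into this interface, as the diffs below show.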

* Exclude non-MLP layers when using FP8 quantization with Llama
parent 1d1b1efa
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -0.6645508,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -2.2324219,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 6088,
"logprob": -1.6074219,
"special": false,
"text": " parse"
},
{
"id": 1243,
"logprob": -1.6298828,
"special": false,
"text": " test"
},
{
"id": 1206,
"logprob": -0.72558594,
"special": false,
"text": " case"
},
{
"id": 1024,
"logprob": -0.40429688,
"special": false,
"text": " name"
},
{
"id": 515,
"logprob": 0.0,
"special": false,
"text": " from"
},
{
"id": 525,
"logprob": -1.2519531,
"special": false,
"text": " '"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not parse test case name from '"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
]
@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
     chunk = [c.replace("data:", "") for c in chunk]
     # remove empty strings
     chunk = [c for c in chunk if c]
+    # remove completion marking chunk
+    chunk = [c for c in chunk if c != " [DONE]"]
     # parse json
     chunk = [json.loads(c) for c in chunk]
...
import pytest
@pytest.fixture(scope="module")
def flash_llama_marlin24_handle(launcher):
with launcher(
"nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin24_handle):
await flash_llama_marlin24_handle.health(300)
return flash_llama_marlin24_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_load(
flash_llama_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
@@ -2,6 +2,7 @@ import pytest
 import torch
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
+    UnquantizedWeight,
     Weights,
     WeightsLoader,
 )
@@ -363,7 +364,10 @@ class MockWeights(Weights):
         self.process_group = process_group
         self.prefix = prefix
         self.weights_loader = (
-            DefaultWeightsLoader() if weights_loader is None else weights_loader
+            # We don't need to get linear layers, so just wrap raw tensors.
+            DefaultWeightsLoader(lambda x: x)
+            if weights_loader is None
+            else weights_loader
         )
         self._handles = {}
...@@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): ...@@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
...@@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): ...@@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): ...@@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
...@@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): ...@@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): ...@@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
...@@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): ...@@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): ...@@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
...@@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): ...@@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): ...@@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
...@@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): ...@@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): ...@@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
...@@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): ...@@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): ...@@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
g_idx=None, g_idx=None,
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=True,
use_exllama=False, use_exllama=False,
) )
...@@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): ...@@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
...@@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): ...@@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0, bits=8.0,
groupsize=2.0, groupsize=2.0,
use_awq_kernel=False,
use_exllama=False, use_exllama=False,
) )
...@@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): ...@@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch" assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......
-import torch
-from loguru import logger
+from dataclasses import dataclass
 from functools import lru_cache

 import bitsandbytes as bnb
+import torch
 from bitsandbytes.nn import Int8Params, Params4bit
+from loguru import logger
+
+from text_generation_server.utils.weights import Weight

 @lru_cache(1)
@@ -12,6 +15,14 @@ def warn_deprecate_bnb():
     )

+@dataclass
+class BNBWeight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
+
 class Linear8bitLt(torch.nn.Module):
     def __init__(
         self,
@@ -70,6 +81,22 @@ class Linear8bitLt(torch.nn.Module):
         return out

+@dataclass
+class BNBFP4Weight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        return Linear4bit(self.weight, bias, quant_type="fp4")
+
+
+@dataclass
+class BNBNF4Weight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        return Linear4bit(self.weight, bias, quant_type="nf4")
+
 class Linear4bit(torch.nn.Module):
     def __init__(self, weight, bias, quant_type):
         super().__init__()
...
+from dataclasses import dataclass
+
 import torch
 from EETQ import quant_weights, w8_a16_gemm

+from text_generation_server.utils.weights import Weight
+
+
+@dataclass
+class EETQWeight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        try:
+            from text_generation_server.layers.eetq import EETQLinear
+
+            return EETQLinear(self.weight, bias)
+        except ImportError:
+            raise ImportError(
+                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+            )
+
 class EETQLinear(torch.nn.Module):
...
-import torch
-from typing import List, Union
 from dataclasses import dataclass
+from typing import List, Union

-from text_generation_server.utils.weights import WeightsLoader, Weights
+import torch
+
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 @dataclass
-class Exl2Weight:
+class Exl2Weight(Weight):
     """
     Exllama2 exl2 quantized weights.
     """
@@ -25,6 +25,11 @@ class Exl2Weight:
     def device(self) -> torch.device:
         return self.q_weight.device

+    def get_linear(self, bias: torch.Tensor):
+        from text_generation_server.layers.gptq import ExllamaQuantLinear
+
+        return ExllamaQuantLinear(self, bias)
+
 class Exl2WeightsLoader(WeightsLoader):
     """Loader for exl2-quantized weights."""
...
-from enum import Enum, auto
+from dataclasses import dataclass

 import torch
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import Weight

 def get_fp8_linear() -> torch.nn.Module:
@@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
     return qweight, scale

+@dataclass
+class Fp8Weight(Weight):
+    weight: torch.Tensor
+
+    def get_linear(self, bias: torch.Tensor):
+        return get_fp8_linear()(self.weight, bias)
+
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
...
-from dataclasses import dataclass
-from loguru import logger
 import os
+from dataclasses import dataclass
 from typing import List, Optional, Union
-from safetensors import SafetensorError
-from text_generation_server.utils.weights import Weights, WeightsLoader
 import torch
-from text_generation_server.utils.import_utils import (
-    SYSTEM,
-)
+from loguru import logger
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
+from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

 @dataclass
-class GPTQWeight:
+class GPTQWeight(Weight):
     qweight: torch.Tensor
     qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: Optional[torch.Tensor]
     bits: int
     groupsize: int
+    use_awq_kernel: bool
     use_exllama: bool

     def __post_init__(self):
@@ -29,6 +28,50 @@ class GPTQWeight:
     def device(self) -> torch.device:
         return self.qweight.device
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
group_size=self.groupsize,
qweight=self.qweight,
qzeros=self.qzeros,
scales=self.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
bias,
self.bits,
self.groupsize,
)
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
...@@ -45,6 +88,8 @@ elif CAN_EXLLAMA: ...@@ -45,6 +88,8 @@ elif CAN_EXLLAMA:
if V2: if V2:
from text_generation_server.layers.gptq.exllamav2 import ( from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
...@@ -53,6 +98,8 @@ elif CAN_EXLLAMA: ...@@ -53,6 +98,8 @@ elif CAN_EXLLAMA:
else: else:
from text_generation_server.layers.gptq.exllama import ( from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers, create_exllama_buffers,
set_device, set_device,
) )
...@@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx, g_idx=g_idx,
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=False, use_exllama=False,
) )
...@@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx, g_idx=g_idx,
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama, use_exllama=use_exllama,
) )
...@@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama = False use_exllama = False
from text_generation_server.layers.gptq import ( from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA, CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight, GPTQWeight,
) )
...@@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx, g_idx=g_idx,
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama, use_exllama=use_exllama,
) )
......
 from typing import Optional

 import torch
-from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
+from torch.nn import functional as F

 if SYSTEM == "rocm":
     try:
@@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module):
         return F.linear(inp, self.weight, self.bias)

-def get_linear(weight, bias, quantize):
-    if quantize is None:
-        if SYSTEM == "rocm":
-            linear = FastLinearROCm(weight, bias)
-        else:
-            linear = FastLinear(weight, bias)
-    elif quantize == "eetq":
-        try:
-            from text_generation_server.layers.eetq import EETQLinear
-            linear = EETQLinear(weight, bias)
-        except ImportError:
-            raise ImportError(
-                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
-            )
-    elif quantize == "fp8":
-        from text_generation_server.layers.fp8 import get_fp8_linear
-        linear = get_fp8_linear()(weight, bias)
-    elif quantize == "bitsandbytes":
-        try:
-            from text_generation_server.layers.bnb import (
-                warn_deprecate_bnb,
-                Linear8bitLt,
-            )
-        except ImportError:
-            raise NotImplementedError(
-                f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
-            )
-        warn_deprecate_bnb()
-        linear = Linear8bitLt(
-            weight,
-            bias,
-            has_fp16_weights=False,
-            threshold=6.0,
-        )
-        if bias is not None:
-            linear.bias = nn.Parameter(bias)
-    elif quantize == "bitsandbytes-fp4":
-        try:
-            from text_generation_server.layers.bnb import Linear4bit
-        except ImportError:
-            raise NotImplementedError(
-                f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
-            )
-        linear = Linear4bit(
-            weight,
-            bias,
-            quant_type="fp4",
-        )
-    elif quantize == "bitsandbytes-nf4":
-        try:
-            from text_generation_server.layers.bnb import Linear4bit
-        except ImportError:
-            raise NotImplementedError(
-                f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
-            )
-        linear = Linear4bit(
-            weight,
-            bias,
-            quant_type="nf4",
-        )
-    elif quantize == "exl2":
-        from text_generation_server.layers.exl2 import Exl2Weight
-        if not isinstance(weight, Exl2Weight):
-            raise NotImplementedError(
-                f"The passed weight is not `exl2` compatible, loader needs to be updated."
-            )
-        from text_generation_server.layers.gptq import ExllamaQuantLinear
-        linear = ExllamaQuantLinear(weight, bias)
-    elif quantize == "gptq":
-        from text_generation_server.layers.gptq import GPTQWeight
-        from text_generation_server.layers.marlin import (
-            GPTQMarlinLinear,
-            GPTQMarlinWeight,
-        )
-        if isinstance(weight, GPTQMarlinWeight):
-            linear = GPTQMarlinLinear(
-                weight=weight,
-                bias=bias,
-            )
-        elif isinstance(weight, GPTQWeight):
-            if weight.use_exllama:
-                try:
-                    from text_generation_server.layers.gptq import (
-                        ExllamaQuantLinear,
-                    )
-                except ImportError:
-                    raise NotImplementedError(
-                        f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
-                    )
-                linear = ExllamaQuantLinear(weight, bias)
-            else:
-                from text_generation_server.layers.gptq.quant_linear import QuantLinear
-                linear = QuantLinear(
-                    weight.qweight,
-                    weight.qzeros,
-                    weight.scales,
-                    weight.g_idx,
-                    bias,
-                    weight.bits,
-                    weight.groupsize,
-                )
-        else:
-            raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
-            )
-    elif quantize == "awq":
-        from text_generation_server.layers.gptq import GPTQWeight
-        if not isinstance(weight, GPTQWeight):
-            raise NotImplementedError(
-                f"The passed weight is not `awq` compatible, loader needs to be updated."
-            )
-        if SYSTEM == "rocm":
-            raise NotImplementedError(
-                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
-                "to use Exllama/GPTQ kernels for AWQ inference."
-            )
-        try:
-            from text_generation_server.layers.awq.quantize.qmodule import WQLinear
-            linear = WQLinear(
-                w_bit=weight.bits,
-                group_size=weight.groupsize,
-                qweight=weight.qweight,
-                qzeros=weight.qzeros,
-                scales=weight.scales,
-                bias=bias,
-            )
-        except ImportError:
-            raise NotImplementedError(
-                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
-            )
-    elif quantize == "marlin":
-        from text_generation_server.layers.marlin import (
-            GPTQMarlin24Linear,
-            GPTQMarlin24Weight,
-            MarlinLinear,
-            MarlinWeight,
-        )
-        if isinstance(weight, GPTQMarlin24Weight):
-            linear = GPTQMarlin24Linear(
-                weight=weight,
-                bias=bias,
-            )
-        elif isinstance(weight, MarlinWeight):
-            linear = MarlinLinear(weight=weight, bias=bias)
-        else:
-            raise NotImplementedError(
-                f"The passed weight is not `marlin` compatible, loader needs to be updated."
-            )
-    else:
-        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
-    return linear
+def get_linear(weight, bias):
+    # Weights that are loaded through methods that are not
+    # quantization-aware are still bare tensors. We may want
+    # to change this in the future.
+    if isinstance(weight, torch.Tensor):
+        if SYSTEM == "rocm":
+            return FastLinearROCm(weight, bias)
+        else:
+            return FastLinear(weight, bias)
+
+    return weight.get_linear(bias)
...@@ -7,7 +7,7 @@ from loguru import logger ...@@ -7,7 +7,7 @@ from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try: try:
import marlin_kernels import marlin_kernels
...@@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader): ...@@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if self.is_marlin_24:
if is_marlin_24:
try: try:
B = torch.cat( B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
...@@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader): ...@@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight return weight
def get_weights_row(self, weights: Weights, prefix: str): def get_weights_row(self, weights: Weights, prefix: str):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" if self.is_marlin_24:
if is_marlin_24:
try: try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0) B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError: except RuntimeError:
...@@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor): ...@@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor):
@dataclass @dataclass
class GPTQMarlinWeight: class GPTQMarlinWeight(Weight):
""" """
Repacked GPTQ Marlin weights. Repacked GPTQ Marlin weights.
""" """
...@@ -219,6 +217,12 @@ class GPTQMarlinWeight: ...@@ -219,6 +217,12 @@ class GPTQMarlinWeight:
assert self.g_idx.dtype == torch.int32 assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32 assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin( def repack_gptq_for_marlin(
*, *,
...@@ -376,6 +380,12 @@ class GPTQMarlin24Weight: ...@@ -376,6 +380,12 @@ class GPTQMarlin24Weight:
assert self.B_meta.dtype == torch.int16 assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16 assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module): class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
...@@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): ...@@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
@dataclass @dataclass
class MarlinWeight: class MarlinWeight(Weight):
""" """
Marlin weights. Marlin weights.
...@@ -583,6 +593,9 @@ class MarlinWeight: ...@@ -583,6 +593,9 @@ class MarlinWeight:
assert self.B.dtype == torch.int32 assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16 assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module): class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
......
...@@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer): ...@@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer):
quantize = config.quantize quantize = config.quantize
return TensorParallelHead( return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize), get_linear(weight, bias=None),
process_group=weights.process_group, process_group=weights.process_group,
should_gather=should_gather, should_gather=should_gather,
) )
...@@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_gate_up only implemented without bias") raise NotImplementedError("packed_gate_up only implemented without bias")
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
...@@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_qkv only implemented for baichuan") raise NotImplementedError("packed_qkv only implemented for baichuan")
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
...@@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
@classmethod @classmethod
...@@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer):
for prefix in prefixes: for prefix in prefixes:
weight = weights.get_weights_col(prefix) weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize)) linears.append(get_linear(weight, b))
linear = LayerConcat(linears) linear = LayerConcat(linears)
else: else:
weight = weights.get_multi_weights_col(prefixes, dim=dim) weight = weights.get_multi_weights_col(prefixes, dim=dim)
...@@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = torch.cat(b, dim=dim) bias = torch.cat(b, dim=dim)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias)
return cls(linear) return cls(linear)
...@@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer): ...@@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer):
else: else:
bias = None bias = None
return cls( return cls(
get_linear(weight, bias, config.quantize), get_linear(weight, bias),
process_group=weights.process_group, process_group=weights.process_group,
) )
......
...@@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights): ...@@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights):
else: else:
bias = None bias = None
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=bias))
get_linear(weight, bias=bias, quantize=config.quantize)
)
class FlashCohereAttention(torch.nn.Module): class FlashCohereAttention(torch.nn.Module):
......
...@@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls): ...@@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls):
if cls == TensorParallelRowLinear: if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous() expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None, config.quantize) linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group)) experts.append(cls(linear, weights.process_group))
else: else:
linear = get_linear(expert_slice, None, config.quantize) linear = get_linear(expert_slice, None)
experts.append(cls(linear)) experts.append(cls(linear))
return experts return experts
......
...@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): ...@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=None))
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemma2Attention(torch.nn.Module): class FlashGemma2Attention(torch.nn.Module):
......
...@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): ...@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear( return TensorParallelColumnLinear(get_linear(weight, bias=None))
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemmaAttention(torch.nn.Module): class FlashGemmaAttention(torch.nn.Module):
......
...@@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights): ...@@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights):
bias = torch.cat(tensors, dim=0) bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device) bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
def _load_qkv(config, prefix: str, weights, head_size, num_heads): def _load_qkv(config, prefix: str, weights, head_size, num_heads):
...@@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads): ...@@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
3 * num_heads * head_size 3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}" ], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool): def load_row(config, prefix: str, weights, bias: bool):
...@@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool): ...@@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = None bias = None
return TensorParallelRowLinear( return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group get_linear(weight, bias), process_group=weights.process_group
) )
...@@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool): ...@@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool):
else: else:
bias = None bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias))
class FlashGPT2Attention(torch.nn.Module): class FlashGPT2Attention(torch.nn.Module):
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -25,7 +26,6 @@ import torch.distributed ...@@ -25,7 +26,6 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
...@@ -42,10 +42,16 @@ from text_generation_server.layers import ( ...@@ -42,10 +42,16 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear, TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear, TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
...@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id): ...@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
) )
@contextmanager
def no_fp8(weights: Weights):
weights_loader = weights.weights_loader
if (
isinstance(weights_loader, DefaultWeightsLoader)
and weights_loader.weight_class is Fp8Weight
):
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
with weights.use_loader(weights_loader):
yield
class FlashLlamaAttention(torch.nn.Module): class FlashLlamaAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
...@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module): ...@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights): def __init__(self, index, prefix, config, weights):
super().__init__() super().__init__()
with no_fp8(weights):
self.self_attn = FlashLlamaAttention( self.self_attn = FlashLlamaAttention(
index=index, index=index,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
config=config, config=config,
weights=weights, weights=weights,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
) )
...@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights): def __init__(self, prefix: str, config, weights):
super().__init__() super().__init__()
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" "model.embed_tokens"
if not prefix
else f"{prefix}.model.embed_tokens"
), ),
weights=weights, weights=weights,
) )
...@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else: else:
suffix = "lm_head" suffix = "lm_head"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=suffix if not prefix else f"{prefix}.{suffix}", prefix=suffix if not prefix else f"{prefix}.{suffix}",
......