Unverified commit ba291dad, authored by Daniël de Kok and committed by GitHub

Improve the handling of quantized weights (#2250)

* Improve the handling of quantized weights

Handling of quantized weights was split between two mechanisms:

- For quantized checkpoints, we used the new weight loader
  infrastructure.
- For quantization while loading (EETQ, FP8, bitsandbytes) we
  instead relied on conditionals in `get_linear`.

Weight loaders support context managers to selectively load
particular layers with different weight loaders, which is useful
for models like Idefics2 AWQ, which uses a quantized text model,
but unquantized vision and connector models. However, the context
manager would be overridden by `get_linear`, which string-checks the
`quantize` argument. Also, the context manager would not work with
EETQ, FP8, and bitsandbytes.

This change migrates all quantizers to the weight loader infrastructure.
This has several benefits:

- We can use context managers with all quantizers.
- All the implementation details move down to the quantizer layers, so
  `get_linear` does not need to know how to handle quantized linear
  layers.
- All quantizer weights are strongly typed; we no longer pass around
  raw tensors.
- We don't have to pass the `quantize` string around everywhere.

* Exclude non-MLP layers when using FP8 quantization with Llama
parent 1d1b1efa
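
To make the new shape of the code concrete before the snapshot fixtures and diffs below, here is a minimal, self-contained sketch of the pattern this change converges on. The `Sketch*` names are illustrative stand-ins, not the actual TGI classes; the `get_linear(bias)` method, the `weights_loader` attribute, and the `use_loader` context manager do correspond to what the diffs in this commit add.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SketchUnquantizedWeight:
    # Stand-in for a strongly typed weight (the real classes are e.g.
    # UnquantizedWeight, GPTQWeight, Fp8Weight). It knows how to turn
    # itself into a linear layer, so no quantizer string is needed.
    weight: torch.Tensor

    def get_linear(self, bias: Optional[torch.Tensor]) -> torch.nn.Module:
        out_features, in_features = self.weight.shape
        linear = torch.nn.Linear(in_features, out_features, bias=bias is not None)
        linear.weight = torch.nn.Parameter(self.weight)
        if bias is not None:
            linear.bias = torch.nn.Parameter(bias)
        return linear


def get_linear(weight, bias):
    # The former if/elif chain over the quantize string collapses to a
    # single dispatch on the typed weight.
    return weight.get_linear(bias)


class SketchWeights:
    # Minimal loader-swapping container: a context manager temporarily
    # replaces the active weights loader, which is what lets a sub-model
    # (e.g. an unquantized vision tower) be loaded differently from the
    # rest of the checkpoint.
    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        previous = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = previous
```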
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -0.6645508,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -2.2324219,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 6088,
"logprob": -1.6074219,
"special": false,
"text": " parse"
},
{
"id": 1243,
"logprob": -1.6298828,
"special": false,
"text": " test"
},
{
"id": 1206,
"logprob": -0.72558594,
"special": false,
"text": " case"
},
{
"id": 1024,
"logprob": -0.40429688,
"special": false,
"text": " name"
},
{
"id": 515,
"logprob": 0.0,
"special": false,
"text": " from"
},
{
"id": 525,
"logprob": -1.2519531,
"special": false,
"text": " '"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not parse test case name from '"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.0859375,
"text": "Test"
},
{
"id": 2009,
"logprob": -16.359375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 5229,
"logprob": -2.7988281,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": -0.91259766,
"special": false,
"text": ":"
},
{
"id": 853,
"logprob": -2.8496094,
"special": false,
"text": " Un"
},
{
"id": 23765,
"logprob": -1.1894531,
"special": false,
"text": "supported"
},
{
"id": 4714,
"logprob": -1.5917969,
"special": false,
"text": " browser"
},
{
"id": 29892,
"logprob": -0.34765625,
"special": false,
"text": ","
},
{
"id": 1873,
"logprob": -1.2695312,
"special": false,
"text": " version"
},
{
"id": 470,
"logprob": -0.25170898,
"special": false,
"text": " or"
},
{
"id": 7481,
"logprob": -0.21411133,
"special": false,
"text": " platform"
},
{
"id": 13,
"logprob": -1.1162109,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": " failed: Unsupported browser, version or platform\n"
}
]
......@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
......
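
As a quick illustration of what the chunk cleanup in the hunk above does, the snippet below runs the same four steps on a made-up streamed payload (the SSE chunk contents are invented for the example, not captured server output):

```python
import json

# Hypothetical raw SSE chunk as the streaming completions endpoint
# might deliver it, split into lines.
chunk = 'data: {"choices": [{"text": " Hello"}]}\n\ndata: [DONE]\n'.split("\n")

chunk = [c.replace("data:", "") for c in chunk]  # strip the SSE prefix
chunk = [c for c in chunk if c]                  # drop empty lines
chunk = [c for c in chunk if c != " [DONE]"]     # drop the completion marker
chunk = [json.loads(c) for c in chunk]           # parse the JSON payloads

assert chunk == [{"choices": [{"text": " Hello"}]}]
```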
import pytest


@pytest.fixture(scope="module")
def flash_llama_marlin24_handle(launcher):
    with launcher(
        "nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin24_handle):
    await flash_llama_marlin24_handle.health(300)
    return flash_llama_marlin24_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin24_load(
    flash_llama_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses == response_snapshot
......@@ -2,6 +2,7 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
......@@ -363,7 +364,10 @@ class MockWeights(Weights):
self.process_group = process_group
self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
# We don't need to get linear layers, so just wrap raw tensors.
DefaultWeightsLoader(lambda x: x)
if weights_loader is None
else weights_loader
)
self._handles = {}
......@@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
......@@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
......@@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
......@@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
......@@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
......@@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
......@@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
g_idx=None,
bits=8.0,
groupsize=2.0,
use_awq_kernel=True,
use_exllama=False,
)
......@@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq):
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......@@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_awq_kernel=False,
use_exllama=False,
)
......@@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader):
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
......
import torch
from loguru import logger
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
from loguru import logger
from text_generation_server.utils.weights import Weight
@lru_cache(1)
......@@ -12,6 +15,14 @@ def warn_deprecate_bnb():
)
@dataclass
class BNBWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)
class Linear8bitLt(torch.nn.Module):
def __init__(
self,
......@@ -70,6 +81,22 @@ class Linear8bitLt(torch.nn.Module):
return out
@dataclass
class BNBFP4Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="fp4")
@dataclass
class BNBNF4Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return Linear4bit(self.weight, bias, quant_type="nf4")
class Linear4bit(torch.nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
......
from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import Weight
@dataclass
class EETQWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module):
......
import torch
from typing import List, Union
from dataclasses import dataclass
from typing import List, Union
from text_generation_server.utils.weights import WeightsLoader, Weights
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class Exl2Weight:
class Exl2Weight(Weight):
"""
Exllama2 exl2 quantized weights.
"""
......@@ -25,6 +25,11 @@ class Exl2Weight:
def device(self) -> torch.device:
return self.q_weight.device
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.gptq import ExllamaQuantLinear
return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
......
from enum import Enum, auto
from dataclasses import dataclass
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weight
def get_fp8_linear() -> torch.nn.Module:
......@@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
return qweight, scale
@dataclass
class Fp8Weight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
return get_fp8_linear()(self.weight, bias)
class Fp8Linear(torch.nn.Module):
def __init__(
self,
......
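
For orientation, `fp8_quantize` (whose tail appears in the hunk above) returns a quantized tensor plus a dequantization scale. The sketch below shows the usual per-tensor FP8 recipe as an assumption about the general technique; it is not copied from TGI's implementation and the details may differ.

```python
import torch


def sketch_fp8_quantize(weight: torch.Tensor, qdtype=torch.float8_e4m3fn):
    # Common per-tensor recipe (assumed, not TGI's exact code): scale the
    # tensor so its largest magnitude maps to the FP8 format's max value.
    finfo = torch.finfo(qdtype)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the quantized weight and the reciprocal scale for dequantization.
    return qweight, scale.float().reciprocal()
```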
from dataclasses import dataclass
from loguru import logger
import os
from dataclasses import dataclass
from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class GPTQWeight:
class GPTQWeight(Weight):
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_awq_kernel: bool
use_exllama: bool
def __post_init__(self):
......@@ -29,6 +28,50 @@ class GPTQWeight:
def device(self) -> torch.device:
return self.qweight.device
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
return WQLinear(
w_bit=self.bits,
group_size=self.groupsize,
qweight=self.qweight,
qzeros=self.qzeros,
scales=self.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
bias,
self.bits,
self.groupsize,
)
try:
major, _minor = torch.cuda.get_device_capability()
......@@ -45,6 +88,8 @@ elif CAN_EXLLAMA:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllamav2 import (
create_exllama_buffers,
set_device,
)
......@@ -53,6 +98,8 @@ elif CAN_EXLLAMA:
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear,
)
from text_generation_server.layers.gptq.exllama import (
create_exllama_buffers,
set_device,
)
......@@ -162,6 +209,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=False,
)
......@@ -255,6 +303,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
......@@ -336,8 +385,8 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
......@@ -389,6 +438,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
......
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
if SYSTEM == "rocm":
try:
......@@ -90,167 +91,14 @@ class FastLinearROCm(torch.nn.Module):
return F.linear(inp, self.weight, self.bias)
def get_linear(weight, bias, quantize):
if quantize is None:
def get_linear(weight, bias):
# Weights that are loaded through methods that are not
# quantization-aware are still bare tensors. We may want
# to change this in the future.
if isinstance(weight, torch.Tensor):
if SYSTEM == "rocm":
linear = FastLinearROCm(weight, bias)
else:
linear = FastLinear(weight, bias)
elif quantize == "eetq":
try:
from text_generation_server.layers.eetq import EETQLinear
linear = EETQLinear(weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "fp8":
from text_generation_server.layers.fp8 import get_fp8_linear
linear = get_fp8_linear()(weight, bias)
elif quantize == "bitsandbytes":
try:
from text_generation_server.layers.bnb import (
warn_deprecate_bnb,
Linear8bitLt,
)
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
warn_deprecate_bnb()
linear = Linear8bitLt(
weight,
bias,
has_fp16_weights=False,
threshold=6.0,
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
try:
from text_generation_server.layers.bnb import Linear4bit
except ImportError:
raise NotImplementedError(
f"Bitsandbytes is missing install it with `pip install bitsandbytes`."
)
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlinLinear,
GPTQMarlinWeight,
)
if isinstance(weight, GPTQMarlinWeight):
linear = GPTQMarlinLinear(
weight=weight,
bias=bias,
)
elif isinstance(weight, GPTQWeight):
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
)
except ImportError:
raise NotImplementedError(
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(weight, bias)
return FastLinearROCm(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
weight.bits,
weight.groupsize,
)
else:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
return FastLinear(weight, bias)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear(
w_bit=weight.bits,
group_size=weight.groupsize,
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias,
)
except ImportError:
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Linear,
GPTQMarlin24Weight,
MarlinLinear,
MarlinWeight,
)
if isinstance(weight, GPTQMarlin24Weight):
linear = GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elif isinstance(weight, MarlinWeight):
linear = MarlinLinear(weight=weight, bias=bias)
else:
raise NotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
return weight.get_linear(bias)
......@@ -7,7 +7,7 @@ from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
......@@ -63,8 +63,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
......@@ -101,8 +100,7 @@ class MarlinWeightsLoader(WeightsLoader):
return weight
def get_weights_row(self, weights: Weights, prefix: str):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
......@@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor):
@dataclass
class GPTQMarlinWeight:
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
......@@ -219,6 +217,12 @@ class GPTQMarlinWeight:
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
......@@ -376,6 +380,12 @@ class GPTQMarlin24Weight:
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
......@@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
@dataclass
class MarlinWeight:
class MarlinWeight(Weight):
"""
Marlin weights.
......@@ -583,6 +593,9 @@ class MarlinWeight:
assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
......
......@@ -77,7 +77,7 @@ class TensorParallelHead(SuperLayer):
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize),
get_linear(weight, bias=None),
process_group=weights.process_group,
should_gather=should_gather,
)
......@@ -134,7 +134,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
......@@ -157,7 +157,7 @@ class TensorParallelColumnLinear(SuperLayer):
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
......@@ -167,7 +167,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
@classmethod
......@@ -177,7 +177,7 @@ class TensorParallelColumnLinear(SuperLayer):
for prefix in prefixes:
weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linears.append(get_linear(weight, b))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(prefixes, dim=dim)
......@@ -186,7 +186,7 @@ class TensorParallelColumnLinear(SuperLayer):
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
linear = get_linear(weight, bias)
return cls(linear)
......@@ -205,7 +205,7 @@ class TensorParallelRowLinear(SuperLayer):
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize),
get_linear(weight, bias),
process_group=weights.process_group,
)
......
......@@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights):
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class FlashCohereAttention(torch.nn.Module):
......
......@@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls):
if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None, config.quantize)
linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group))
else:
linear = get_linear(expert_slice, None, config.quantize)
linear = get_linear(expert_slice, None)
experts.append(cls(linear))
return experts
......
......@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma2Attention(torch.nn.Module):
......
......@@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemmaAttention(torch.nn.Module):
......
......@@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights):
bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
......@@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads):
3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool):
......@@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
get_linear(weight, bias), process_group=weights.process_group
)
......@@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool):
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias))
class FlashGPT2Attention(torch.nn.Module):
......
......@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
......@@ -25,7 +26,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
......@@ -42,10 +42,16 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm":
try:
......@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
)
@contextmanager
def no_fp8(weights: Weights):
weights_loader = weights.weights_loader
if (
isinstance(weights_loader, DefaultWeightsLoader)
and weights_loader.weight_class is Fp8Weight
):
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
with weights.use_loader(weights_loader):
yield
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
......@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights):
super().__init__()
with no_fp8(weights):
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
......@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
"model.embed_tokens"
if not prefix
else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
......@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
suffix = "lm_head"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
......