Commit 36dd1601 authored by Daniël de Kok's avatar Daniël de Kok Committed by Daniël de Kok
Browse files

Add support for exl2 quantization

Mostly straightforward, changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing
  around untyped tuples and needing to repack them as a dict.
* Move scratch space computation to warmup, because we need the
  maximum input sequence length to avoid allocating huge
  scratch buffers that OOM.
parent cbced7f0
......@@ -62,6 +62,7 @@ Options:
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
......
......@@ -2,7 +2,6 @@
## What is Guidance?
Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
## How is it used?
......
......@@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
def serialize(
self,
......@@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
return (
token.id == other.id
and token.text == other.text
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
and (
self.ignore_logprob
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
)
and token.special == other.special
)
......@@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id
and prefill_token.text == other.text
and (
math.isclose(
prefill_token.logprob, other.logprob, rel_tol=self.rtol
self.ignore_logprob
or math.isclose(
prefill_token.logprob,
other.logprob,
rel_tol=self.rtol,
)
if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob
......@@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
rtol = 0.75
class IgnoreLogProbResponseComparator(ResponseComparator):
ignore_logprob = True
class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}")
......@@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator)
@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
return snapshot.use_extension(IgnoreLogProbResponseComparator)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9316406,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5136719,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.7783203,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2314453,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -2.0019531,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.5009766,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.057434082,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4912109,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2636719,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.9980469,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15795898,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -1.0458984,
"special": false,
"text": " server"
},
{
"id": 31680,
"logprob": -1.3623047,
"special": false,
"text": " responds"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " with"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 330,
"logprob": -0.5678711,
"special": false,
"text": " \""
},
{
"id": 1049,
"logprob": -0.12322998,
"special": false,
"text": "200"
},
{
"id": 10619,
"logprob": 0.0,
"special": false,
"text": " OK"
},
{
"id": 1,
"logprob": 0.0,
"special": false,
"text": "\""
}
],
"top_tokens": null
},
"generated_text": "Test request. The server responds with a \"200 OK\""
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9785156,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4941406,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.79345703,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2324219,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9794922,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.058258057,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4892578,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2783203,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3945312,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.40625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9433594,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4726562,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8022461,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2509766,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4677734,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059173584,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4990234,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2822266,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3867188,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.421875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9511719,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.46875,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.77490234,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2558594,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4990234,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059143066,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4941406,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2578125,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3964844,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4140625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9101562,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5039062,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8076172,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2236328,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9853516,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.056671143,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.5107422,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2597656,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
]
import pytest
@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",
revision="2.5bpw",
# Set max input length to avoid OOM due to extremely large
# scratch buffer.
max_input_length=1024,
num_shard=1,
quantize="exl2",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
await flash_llama_exl2_handle.health(300)
return flash_llama_exl2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
response = await flash_llama_exl2.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
flash_llama_exl2, ignore_logprob_response_snapshot
):
response = await flash_llama_exl2.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert (
response.generated_text == 'Test request. The server responds with a "200 OK"'
)
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_llama_exl2, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == ignore_logprob_response_snapshot
......@@ -55,6 +55,10 @@ enum Quantization {
/// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq,
/// Variable bit quantization. Requires a specific EXL2 quantized model:
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
/// not support tensor parallelism (num_shard > 1).
Exl2,
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not.
......@@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization {
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Exl2 => {
write!(f, "exl2")
}
Quantization::Gptq => {
write!(f, "gptq")
}
......@@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> {
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}
......
......@@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
......
import torch
from dataclasses import dataclass
@dataclass
class Exl2Weight:
"""
Exllama2 exl2 quantized weights.
"""
q_weight: torch.Tensor
q_scale: torch.Tensor
q_invperm: torch.Tensor
q_scale_max: torch.Tensor
q_groups: torch.Tensor
def __post_init__(self):
self.q_scale_max /= 256
self.q_invperm = self.q_invperm.short()
@property
def device(self) -> torch.device:
return self.q_weight.device
from dataclasses import dataclass
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
@dataclass
class GPTQWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_exllama: bool
def __post_init__(self):
if self.scales.dtype == torch.float:
self.scales = self.scales.half()
@property
def device(self) -> torch.device:
return self.qweight.device
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
......
from text_generation_server.utils.weights import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
......@@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int):
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
def __init__(self, weight: GPTQWeight, bias):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4
assert weight.bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.device = weight.qweight.device
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and (
(self.g_idx == 0).all()
or torch.equal(
g_idx.cpu(),
weight.g_idx.cpu(),
torch.tensor(
[i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
),
)
):
......@@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module):
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
self.height = weight.qweight.shape[0] * 8
self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
......@@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module):
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
assert weight.groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
......
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError:
......@@ -15,6 +20,15 @@ except ImportError:
none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
......@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
......@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
"""
Create Q matrix
"""
# EXL2
# won't work as the moment because the tensors are not the same.
if "q_weight" in w:
w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
if isinstance(w, Exl2Weight):
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
extra.q_perm = torch.argsort(w.q_invperm).short()
return make_q_matrix(
w["q_weight"],
w["q_perm"],
w["q_invperm"],
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
w["q_group_map"],
w.q_weight,
extra.q_perm,
w.q_invperm,
w.q_scale,
w.q_scale_max,
w.q_groups,
extra.q_group_map,
none_tensor,
none_tensor,
none_tensor,
temp_dq,
)
# GPTQ
elif "qweight" in w:
if w["scales"].dtype == torch.float:
w["scales"] = w["scales"].half()
elif isinstance(w, GPTQWeight):
if w.scales.dtype == torch.float:
w.scales = w.scales.half()
# GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty(
(w["qweight"].shape[0] * 8,),
if w.g_idx is not None and not (w.g_idx == 0).all().item():
extra.q_perm = torch.empty(
(w.qweight.shape[0] * 8,),
dtype=torch.short,
device=w["qweight"].device,
device=w.qweight.device,
)
w["q_invperm"] = torch.empty_like(w["q_perm"])
extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(
w["qweight"],
w["q_perm"],
w["q_invperm"],
w.qweight,
extra.q_perm,
extra.q_invperm,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
w.qzeros,
w.scales,
w.g_idx.cpu(),
temp_dq,
)
# GPTQ without g_idx
else:
return make_q_matrix(
w["qweight"],
w.qweight,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
w["qzeros"],
w["scales"],
w.qzeros,
w.scales,
none_tensor,
temp_dq,
)
......@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
DEVICE = None
FIXED_BYTES = 0
LAYERS = []
......@@ -134,8 +143,13 @@ def set_device(device):
def create_exllama_buffers(max_total_tokens: int):
global FIXED_BYTES, LAYERS, DEVICE
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
global LAYERS, DEVICE
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS:
layer.post_init(temp_dq)
......@@ -146,49 +160,48 @@ class QuantLinear(nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
def __init__(
self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None
self.q_tensors = None
self.bits = bits
self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1]
self.q_tensors = weight
self.extra_tensors = _ExtraTensors()
if isinstance(weight, Exl2Weight):
self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx
self.device = weight.device
self.bias = bias if bias is not None else None
self.group_size = groupsize
global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
global LAYERS
LAYERS.append(self)
def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.q_tensors = {
"qweight": self.qweight,
"qzeros": self.qzeros,
"scales": self.scales,
"g_idx": self.g_idx,
}
device = self.q_tensors.device
assert device.type == "cuda"
assert device.index is not None
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
......
from typing import Optional
import torch
from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
if SYSTEM == "rocm":
try:
......@@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize):
bias,
quant_type="nf4",
)
elif quantize == "exl2":
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
if use_exllama:
if weight.use_exllama:
try:
from text_generation_server.layers.gptq import (
ExllamaQuantLinear,
......@@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize):
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
linear = ExllamaQuantLinear(
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
linear = ExllamaQuantLinear(weight, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
weight.qweight,
weight.qzeros,
weight.scales,
weight.g_idx,
bias,
bits,
groupsize,
weight.bits,
weight.groupsize,
)
elif quantize == "awq":
try:
qweight, qzeros, scales, _, bits, groupsize, _ = weight
except Exception:
if not isinstance(weight, GPTQWeight):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
......@@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize):
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear(
w_bit=bits,
group_size=groupsize,
qweight=qweight,
qzeros=qzeros,
scales=scales,
w_bit=weight.bits,
group_size=weight.groupsize,
qweight=weight.qweight,
qzeros=weight.qzeros,
scales=weight.scales,
bias=bias is not None,
)
except ImportError:
......
import torch
from torch.nn import functional as F
from typing import List
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
class LayerConcat(torch.nn.Module):
"""
Apply multiple layers to the input and concatenate their
outputs.
"""
def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
"""
`dim` is the dimension along which layer outputs are concatenated.
"""
super().__init__()
self.layers = layers
self.dim = dim
def forward(self, x: torch.Tensor):
outputs = [layer(x) for layer in self.layers]
return torch.cat(outputs, self.dim)
class SuperLayer(torch.nn.Module):
......@@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer):
@staticmethod
def load(config, prefix: str, weights):
if weights.process_group.size() > 1:
if config.quantize == "exl2":
try:
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
......@@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer):
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq"]:
quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
quantize = None
else:
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group,
......@@ -108,22 +141,35 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
return cls.load_multi(config, [prefix], weights, bias, dim=0)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
weight = weights.get_weights_col(prefix, config.quantize)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
......
......@@ -263,7 +263,7 @@ def get_model(
trust_remote_code: bool,
) -> Model:
if dtype is None:
if quantize in ["awq", "gptq"]:
if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params.
dtype = torch.float16
else:
......@@ -402,12 +402,17 @@ def get_model(
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq"}:
if method in {"gptq", "awq", "exl2"}:
logger.info(f"Auto selecting quantization method {method}")
quantize = method
else:
logger.info(f"Unknown quantization method {method}")
if quantize == "exl2" and sharded:
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)
if model_type == MAMBA:
return Mamba(
model_id,
......@@ -881,6 +886,8 @@ def get_model(
raise NotImplementedError("4bit quantization is not supported for AutoModel")
elif quantize == "eetq":
raise NotImplementedError("Eetq quantization is not supported for AutoModel")
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
......
......@@ -21,6 +21,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
......@@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights):
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
......
......@@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.suffix",
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
......
......@@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig):
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class MistralAttention(torch.nn.Module):
def __init__(
self,
......@@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
......
......@@ -5,6 +5,7 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
......@@ -90,8 +91,15 @@ def _load_multi_mqa_gptq(
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = HAS_EXLLAMA
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=HAS_EXLLAMA,
)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment