Commit 36dd1601 authored by Daniël de Kok, committed by Daniël de Kok

Add support for exl2 quantization

Mostly straightforward. Changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing
  around untyped tuples and needing to repack them as a dict
  (see the sketch just below).
* Move scratch space computation to warmup, because we need the
  maximum input sequence length to avoid allocating huge
  scratch buffers that would OOM (a sizing sketch follows the
  exllamav2 kernel diff further down).
parent cbced7f0
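To make the first point concrete, here is a condensed sketch of the wrapper pattern. It mirrors the `GPTQWeight` dataclass introduced further down in this diff (the `__post_init__` and `device` helpers are omitted), and the commented lines show the tuple-based shape it replaces.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GPTQWeight:
    """Typed wrapper around GPTQ quantizer parameters (condensed from the diff below)."""

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool


# Before: an untyped tuple that every consumer had to unpack positionally.
#   weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
# After: consumers receive one object and read named, typed fields, e.g.
#   linear = ExllamaQuantLinear(weight, bias)  # uses weight.bits, weight.groupsize, ...
```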
@@ -62,6 +62,7 @@ Options:
  Possible values:
  - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
  - eetq: 8 bit quantization, doesn't require a specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
+ - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
  - gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use the triton kernel (wider support) when it's not. AWQ has faster kernels
  - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
  - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
......
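For orientation, this is roughly how an exl2 deployment is exercised from Python, using the same `text_generation.AsyncClient` that the integration tests below rely on. The port and the example model/revision are assumptions for illustration, not part of this change.

```python
import asyncio

from text_generation import AsyncClient

# Assumes a server started with `--quantize exl2 --num-shard 1` on an
# exl2-quantized checkpoint (e.g. turboderp/Llama-3-8B-Instruct-exl2,
# revision 2.5bpw, as used in the integration tests) listening on port 3000.


async def main():
    client = AsyncClient("http://localhost:3000")
    response = await client.generate("Test request", max_new_tokens=10)
    print(response.generated_text)


asyncio.run(main())
```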
@@ -2,7 +2,6 @@
  ## What is Guidance?
  Guidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure, uses a specific set of words, or produces output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.
  ## How is it used?
......
...@@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") ...@@ -38,6 +38,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
ignore_logprob = False
def serialize( def serialize(
self, self,
...@@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension): ...@@ -95,7 +96,10 @@ class ResponseComparator(JSONSnapshotExtension):
return ( return (
token.id == other.id token.id == other.id
and token.text == other.text and token.text == other.text
and math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) and (
self.ignore_logprob
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
)
and token.special == other.special and token.special == other.special
) )
...@@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension): ...@@ -105,8 +109,11 @@ class ResponseComparator(JSONSnapshotExtension):
prefill_token.id == other.id prefill_token.id == other.id
and prefill_token.text == other.text and prefill_token.text == other.text
and ( and (
math.isclose( self.ignore_logprob
prefill_token.logprob, other.logprob, rel_tol=self.rtol or math.isclose(
prefill_token.logprob,
other.logprob,
rel_tol=self.rtol,
) )
if prefill_token.logprob is not None if prefill_token.logprob is not None
else prefill_token.logprob == other.logprob else prefill_token.logprob == other.logprob
...@@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator): ...@@ -223,6 +230,10 @@ class GenerousResponseComparator(ResponseComparator):
rtol = 0.75 rtol = 0.75
class IgnoreLogProbResponseComparator(ResponseComparator):
ignore_logprob = True
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}") self.client = AsyncClient(f"http://localhost:{port}")
...@@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot): ...@@ -274,6 +285,11 @@ def generous_response_snapshot(snapshot):
return snapshot.use_extension(GenerousResponseComparator) return snapshot.use_extension(GenerousResponseComparator)
@pytest.fixture
def ignore_logprob_response_snapshot(snapshot):
return snapshot.use_extension(IgnoreLogProbResponseComparator)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop(): def event_loop():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9316406,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5136719,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.7783203,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2314453,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -2.0019531,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.5009766,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.057434082,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4912109,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2636719,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -1.9980469,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15795898,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -1.0458984,
"special": false,
"text": " server"
},
{
"id": 31680,
"logprob": -1.3623047,
"special": false,
"text": " responds"
},
{
"id": 449,
"logprob": 0.0,
"special": false,
"text": " with"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 330,
"logprob": -0.5678711,
"special": false,
"text": " \""
},
{
"id": 1049,
"logprob": -0.12322998,
"special": false,
"text": "200"
},
{
"id": 10619,
"logprob": 0.0,
"special": false,
"text": " OK"
},
{
"id": 1,
"logprob": 0.0,
"special": false,
"text": "\""
}
],
"top_tokens": null
},
"generated_text": "Test request. The server responds with a \"200 OK\""
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.453125,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9785156,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4941406,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.79345703,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2324219,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9794922,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.058258057,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4892578,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2783203,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3945312,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.40625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9433594,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.4726562,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8022461,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2509766,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4677734,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059173584,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4990234,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2822266,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3867188,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.421875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9511719,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.46875,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.77490234,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2558594,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.984375,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4990234,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.059143066,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.4941406,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2578125,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.3964844,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.4140625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.9101562,
"special": false,
"text": ":"
},
{
"id": 330,
"logprob": -3.5039062,
"special": false,
"text": " \""
},
{
"id": 489,
"logprob": -0.8076172,
"special": false,
"text": " +"
},
{
"id": 1715,
"logprob": -1.2236328,
"special": false,
"text": " request"
},
{
"id": 489,
"logprob": -1.9853516,
"special": false,
"text": " +"
},
{
"id": 2990,
"logprob": -1.4892578,
"special": false,
"text": " \"\\"
},
{
"id": 77,
"logprob": -0.056671143,
"special": false,
"text": "n"
},
{
"id": 702,
"logprob": -1.5107422,
"special": false,
"text": "\"\n"
},
{
"id": 262,
"logprob": -1.2597656,
"special": false,
"text": " "
},
{
"id": 557,
"logprob": -2.4042969,
"special": false,
"text": " }\n\n"
}
],
"top_tokens": null
},
"generated_text": ": \" + request + \"\\n\"\n }\n\n"
}
]
import pytest
@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",
revision="2.5bpw",
# Set max input length to avoid OOM due to extremely large
# scratch buffer.
max_input_length=1024,
num_shard=1,
quantize="exl2",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
await flash_llama_exl2_handle.health(300)
return flash_llama_exl2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
response = await flash_llama_exl2.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
flash_llama_exl2, ignore_logprob_response_snapshot
):
response = await flash_llama_exl2.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert (
response.generated_text == 'Test request. The server responds with a "200 OK"'
)
assert response == ignore_logprob_response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_llama_exl2, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == ignore_logprob_response_snapshot
...@@ -55,6 +55,10 @@ enum Quantization { ...@@ -55,6 +55,10 @@ enum Quantization {
/// Should be a drop-in replacement to bitsandbytes with much better performance. /// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git> /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq, Eetq,
/// Variable bit quantization. Requires a specific EXL2 quantized model:
/// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does
/// not support tensor parallelism (num_shard > 1).
Exl2,
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not. /// triton kernel (wider support) when it's not.
...@@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization { ...@@ -95,6 +99,9 @@ impl std::fmt::Display for Quantization {
Quantization::BitsandbytesFP4 => { Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4") write!(f, "bitsandbytes-fp4")
} }
Quantization::Exl2 => {
write!(f, "exl2")
}
Quantization::Gptq => { Quantization::Gptq => {
write!(f, "gptq") write!(f, "gptq")
} }
...@@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> { ...@@ -1461,6 +1468,11 @@ fn main() -> Result<(), LauncherError> {
let num_shard = find_num_shards(args.sharded, args.num_shard)?; let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 { if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
} }
......
...@@ -19,6 +19,7 @@ class Quantization(str, Enum): ...@@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq" gptq = "gptq"
awq = "awq" awq = "awq"
eetq = "eetq" eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8" fp8 = "fp8"
......
import torch
from dataclasses import dataclass
@dataclass
class Exl2Weight:
"""
Exllama2 exl2 quantized weights.
"""
q_weight: torch.Tensor
q_scale: torch.Tensor
q_invperm: torch.Tensor
q_scale_max: torch.Tensor
q_groups: torch.Tensor
def __post_init__(self):
self.q_scale_max /= 256
self.q_invperm = self.q_invperm.short()
@property
def device(self) -> torch.device:
return self.q_weight.device
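The loader that fills this dataclass (`weights.get_weights_col` for exl2) is not part of this excerpt, so the helper below is only a sketch of the idea: gather the five exl2 tensors stored under a layer prefix and hand them to `Exl2Weight`. The `get_tensor` callable is a hypothetical stand-in for the real tensor accessor.

```python
from text_generation_server.layers.exl2 import Exl2Weight


def load_exl2_weight(get_tensor, prefix: str) -> Exl2Weight:
    # `get_tensor` is a stand-in for whatever returns a tensor by name
    # (e.g. a safetensors accessor); the tensor names follow the exl2
    # checkpoint layout assumed here.
    return Exl2Weight(
        q_weight=get_tensor(f"{prefix}.q_weight"),
        q_scale=get_tensor(f"{prefix}.q_scale"),
        q_invperm=get_tensor(f"{prefix}.q_invperm"),
        q_scale_max=get_tensor(f"{prefix}.q_scale_max"),
        q_groups=get_tensor(f"{prefix}.q_groups"),
    )
```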
from dataclasses import dataclass
import os import os
from typing import Optional
import torch import torch
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
SYSTEM, SYSTEM,
) )
@dataclass
class GPTQWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: Optional[torch.Tensor]
bits: int
groupsize: int
use_exllama: bool
def __post_init__(self):
if self.scales.dtype == torch.float:
self.scales = self.scales.half()
@property
def device(self) -> torch.device:
return self.qweight.device
try: try:
major, _minor = torch.cuda.get_device_capability() major, _minor = torch.cuda.get_device_capability()
except Exception: except Exception:
......
from text_generation_server.utils.weights import GPTQWeight
import torch import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
...@@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int): ...@@ -65,24 +66,25 @@ def create_exllama_buffers(max_total_tokens: int):
class Ex4bitLinear(torch.nn.Module): class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights""" """Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): def __init__(self, weight: GPTQWeight, bias):
super().__init__() super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4 assert weight.bits == 4
self.device = qweight.device self.device = weight.qweight.device
self.qweight = qweight self.qweight = weight.qweight
self.qzeros = qzeros self.qzeros = weight.qzeros
self.scales = scales self.scales = weight.scales
self.g_idx = g_idx.cpu() if g_idx is not None else None self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None self.bias = bias if bias is not None else None
if self.g_idx is not None and ( if self.g_idx is not None and (
(self.g_idx == 0).all() (self.g_idx == 0).all()
or torch.equal( or torch.equal(
g_idx.cpu(), weight.g_idx.cpu(),
torch.tensor( torch.tensor(
[i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32 [i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
), ),
) )
): ):
...@@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module): ...@@ -96,8 +98,8 @@ class Ex4bitLinear(torch.nn.Module):
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
) )
self.height = qweight.shape[0] * 8 self.height = weight.qweight.shape[0] * 8
self.width = qweight.shape[1] self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros # Infer groupsize from height of qzeros
self.groupsize = None self.groupsize = None
...@@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module): ...@@ -105,7 +107,7 @@ class Ex4bitLinear(torch.nn.Module):
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None: if self.groupsize is not None:
assert groupsize == self.groupsize assert weight.groupsize == self.groupsize
# Handle act-order matrix # Handle act-order matrix
if self.g_idx is not None: if self.g_idx is not None:
......
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from loguru import logger from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
try: try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half from exllamav2_kernels import make_q_matrix, gemm_half_q_half
except ImportError: except ImportError:
...@@ -15,6 +20,15 @@ except ImportError: ...@@ -15,6 +20,15 @@ except ImportError:
none_tensor = torch.empty((1, 1), device="meta") none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4""" """Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,) output_shape = x.shape[:-1] + (q4_width,)
...@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): ...@@ -24,11 +38,7 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape) return output.view(output_shape)
# Group map needed for irregular group sizes def make_group_map(q_groups: torch.Tensor, num_qrows: int):
def make_group_map(q_groups, num_qrows):
gr = q_groups.tolist() gr = q_groups.tolist()
group_map = [] group_map = []
num_groups = len(gr) // 2 num_groups = len(gr) // 2
...@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows): ...@@ -50,72 +60,72 @@ def make_group_map(q_groups, num_qrows):
# Create Q matrix # Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None): def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
""" """
Create Q matrix Create Q matrix
""" """
# EXL2 # EXL2
# won't work as the moment because the tensors are not the same. if isinstance(w, Exl2Weight):
if "q_weight" in w: extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
w["q_scale_max"] /= 256 extra.q_perm = torch.argsort(w.q_invperm).short()
w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
return make_q_matrix( return make_q_matrix(
w["q_weight"], w.q_weight,
w["q_perm"], extra.q_perm,
w["q_invperm"], w.q_invperm,
w["q_scale"], w.q_scale,
w["q_scale_max"], w.q_scale_max,
w["q_groups"], w.q_groups,
w["q_group_map"], extra.q_group_map,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
temp_dq, temp_dq,
) )
# GPTQ # GPTQ
elif "qweight" in w: elif isinstance(w, GPTQWeight):
if w["scales"].dtype == torch.float: if w.scales.dtype == torch.float:
w["scales"] = w["scales"].half() w.scales = w.scales.half()
# GPTQ with g_idx (act_order) # GPTQ with g_idx (act_order)
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item(): if w.g_idx is not None and not (w.g_idx == 0).all().item():
w["q_perm"] = torch.empty( extra.q_perm = torch.empty(
(w["qweight"].shape[0] * 8,), (w.qweight.shape[0] * 8,),
dtype=torch.short, dtype=torch.short,
device=w["qweight"].device, device=w.qweight.device,
) )
w["q_invperm"] = torch.empty_like(w["q_perm"]) extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix( return make_q_matrix(
w["qweight"], w.qweight,
w["q_perm"], extra.q_perm,
w["q_invperm"], extra.q_invperm,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
w["qzeros"], w.qzeros,
w["scales"], w.scales,
w["g_idx"].cpu(), w.g_idx.cpu(),
temp_dq, temp_dq,
) )
# GPTQ without g_idx # GPTQ without g_idx
else: else:
return make_q_matrix( return make_q_matrix(
w["qweight"], w.qweight,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
w["qzeros"], w.qzeros,
w["scales"], w.scales,
none_tensor, none_tensor,
temp_dq, temp_dq,
) )
...@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ...@@ -124,7 +134,6 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
DEVICE = None DEVICE = None
FIXED_BYTES = 0
LAYERS = [] LAYERS = []
...@@ -134,8 +143,13 @@ def set_device(device): ...@@ -134,8 +143,13 @@ def set_device(device):
def create_exllama_buffers(max_total_tokens: int): def create_exllama_buffers(max_total_tokens: int):
global FIXED_BYTES, LAYERS, DEVICE global LAYERS, DEVICE
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS: for layer in LAYERS:
layer.post_init(temp_dq) layer.post_init(temp_dq)
...@@ -146,49 +160,48 @@ class QuantLinear(nn.Module): ...@@ -146,49 +160,48 @@ class QuantLinear(nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights""" """Linear layer implementation with per-group 4-bit quantization of the weights"""
# def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): def __init__(
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__() super().__init__()
if bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization."
)
self.q_handle = None self.q_handle = None
self.q_tensors = None self.q_tensors = weight
self.bits = bits self.extra_tensors = _ExtraTensors()
self.maxq = 2**self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32 if isinstance(weight, Exl2Weight):
self.outfeatures = qweight.shape[1] self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32 self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding self.outfeatures = self.outfeatures + self.padding
self.device = qweight.device self.device = weight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx
self.bias = bias if bias is not None else None self.bias = bias if bias is not None else None
self.group_size = groupsize
global FIXED_BYTES, LAYERS global LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
LAYERS.append(self) LAYERS.append(self)
def post_init(self, temp_dq): def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda" device = self.q_tensors.device
assert self.qweight.device.index is not None assert device.type == "cuda"
self.q_tensors = { assert device.index is not None
"qweight": self.qweight,
"qzeros": self.qzeros,
"scales": self.scales,
"g_idx": self.g_idx,
}
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us, # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you. # and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq) self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False): def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
......
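To summarize the warmup change above: every `QuantLinear` registers itself in `LAYERS`, and at warmup the shared scratch buffer is sized to the largest per-layer requirement at the maximum input length. The sketch below captures that structure; the byte formula inside `scratch_space_fixed` here is illustrative only, not the kernel's exact accounting.

```python
from typing import List


class FakeQuantLinear:
    """Stand-in layer: only the scratch-sizing logic matters for this sketch."""

    def __init__(self, infeatures: int, outfeatures: int):
        self.infeatures = infeatures
        self.outfeatures = outfeatures

    def scratch_space_fixed(self, max_input_len: int) -> int:
        # Illustrative: space to dequantize the weight once (fp16), plus a
        # forward workspace that grows with the longest possible input.
        dequant_bytes = self.infeatures * self.outfeatures * 2
        forward_bytes = self.outfeatures * max_input_len * 2 * 4
        return dequant_bytes + forward_bytes


def required_scratch_bytes(layers: List[FakeQuantLinear], max_total_tokens: int) -> int:
    # Same structure as create_exllama_buffers: one shared buffer, sized for
    # the most demanding layer at the maximum sequence length.
    return max(
        layer.scratch_space_fixed(max_input_len=max_total_tokens) for layer in layers
    )


layers = [FakeQuantLinear(4096, 14336), FakeQuantLinear(4096, 4096)]
print(required_scratch_bytes(layers, max_total_tokens=1024))
```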
from typing import Optional
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
...@@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize): ...@@ -151,15 +154,23 @@ def get_linear(weight, bias, quantize):
bias, bias,
quant_type="nf4", quant_type="nf4",
) )
elif quantize == "exl2":
if not isinstance(weight, Exl2Weight):
raise NotImplementedError(
f"The passed weight is not `exl2` compatible, loader needs to be updated."
)
from text_generation_server.layers.gptq import ExllamaQuantLinear
linear = ExllamaQuantLinear(weight, bias)
elif quantize == "gptq": elif quantize == "gptq":
try: if not isinstance(weight, GPTQWeight):
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
raise NotImplementedError( raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated." f"The passed weight is not `gptq` compatible, loader needs to be updated."
) )
if use_exllama: if weight.use_exllama:
try: try:
from text_generation_server.layers.gptq import ( from text_generation_server.layers.gptq import (
ExllamaQuantLinear, ExllamaQuantLinear,
...@@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize): ...@@ -169,25 +180,21 @@ def get_linear(weight, bias, quantize):
f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
) )
linear = ExllamaQuantLinear( linear = ExllamaQuantLinear(weight, bias)
qweight, qzeros, scales, g_idx, bias, bits, groupsize
)
else: else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear from text_generation_server.layers.gptq.quant_linear import QuantLinear
linear = QuantLinear( linear = QuantLinear(
qweight, weight.qweight,
qzeros, weight.qzeros,
scales, weight.scales,
g_idx, weight.g_idx,
bias, bias,
bits, weight.bits,
groupsize, weight.groupsize,
) )
elif quantize == "awq": elif quantize == "awq":
try: if not isinstance(weight, GPTQWeight):
qweight, qzeros, scales, _, bits, groupsize, _ = weight
except Exception:
raise NotImplementedError( raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated." f"The passed weight is not `awq` compatible, loader needs to be updated."
) )
...@@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize): ...@@ -200,11 +207,11 @@ def get_linear(weight, bias, quantize):
from text_generation_server.layers.awq.quantize.qmodule import WQLinear from text_generation_server.layers.awq.quantize.qmodule import WQLinear
linear = WQLinear( linear = WQLinear(
w_bit=bits, w_bit=weight.bits,
group_size=groupsize, group_size=weight.groupsize,
qweight=qweight, qweight=weight.qweight,
qzeros=qzeros, qzeros=weight.qzeros,
scales=scales, scales=weight.scales,
bias=bias is not None, bias=bias is not None,
) )
except ImportError: except ImportError:
......
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from typing import List from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
class LayerConcat(torch.nn.Module):
"""
Apply multiple layers to the input and concatenate their
outputs.
"""
def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
"""
`dim` is the dimension along which layer outputs are concatenated.
"""
super().__init__()
self.layers = layers
self.dim = dim
def forward(self, x: torch.Tensor):
outputs = [layer(x) for layer in self.layers]
return torch.cat(outputs, self.dim)
class SuperLayer(torch.nn.Module): class SuperLayer(torch.nn.Module):
...@@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer): ...@@ -21,7 +41,16 @@ class TensorParallelHead(SuperLayer):
@staticmethod @staticmethod
def load(config, prefix: str, weights): def load(config, prefix: str, weights):
if weights.process_group.size() > 1: if config.quantize == "exl2":
try:
# If the piece and LM head embeddings are shared, we have
# non-quantized weights...
weight = weights.get_tensor(f"{prefix}.weight")
except:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try: try:
weight = weights.get_sharded(f"{prefix}.weight", dim=0) weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True should_gather = True
...@@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer): ...@@ -37,8 +66,12 @@ class TensorParallelHead(SuperLayer):
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq"]: if config.quantize in ["gptq", "awq", "eetq"]:
quantize = None quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
quantize = None
else: else:
quantize = config.quantize quantize = config.quantize
return TensorParallelHead( return TensorParallelHead(
get_linear(weight, bias=None, quantize=quantize), get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group, process_group=weights.process_group,
...@@ -108,22 +141,35 @@ class TensorParallelColumnLinear(SuperLayer): ...@@ -108,22 +141,35 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod @classmethod
def load(cls, config, prefix: str, weights, bias: bool): def load(cls, config, prefix: str, weights, bias: bool):
return cls.load_multi(config, [prefix], weights, bias, dim=0) weight = weights.get_weights_col(prefix, config.quantize)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
if bias: if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = weights.get_sharded(f"{prefix}.bias", dim=0)
bias = torch.cat(b, dim=dim)
else: else:
bias = None bias = None
linear = get_linear(weight, bias, config.quantize) linear = get_linear(weight, bias, config.quantize)
return cls(linear) return cls(linear)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
class TensorParallelRowLinear(SuperLayer): class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group): def __init__(self, linear, process_group):
......
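Since `LayerConcat` is new here, a small self-contained usage example may help. The class is repeated below so the snippet runs on its own, and plain `nn.Linear` layers stand in for the per-prefix quantized linears that `load_multi` actually builds for exl2.

```python
import torch
from torch import nn


class LayerConcat(nn.Module):
    """Apply multiple layers to the input and concatenate their outputs
    along `dim` (same shape as the class added in this diff)."""

    def __init__(self, layers, dim: int = -1):
        super().__init__()
        self.layers = layers
        self.dim = dim

    def forward(self, x: torch.Tensor):
        return torch.cat([layer(x) for layer in self.layers], self.dim)


# Emulate a fused q/k/v projection by concatenating three per-prefix
# projections, which is what load_multi does for exl2 in this diff
# (separate linears rather than one merged weight).
q_proj, k_proj, v_proj = nn.Linear(16, 8), nn.Linear(16, 4), nn.Linear(16, 4)
fused = LayerConcat([q_proj, k_proj, v_proj])
print(fused(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```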
...@@ -263,7 +263,7 @@ def get_model( ...@@ -263,7 +263,7 @@ def get_model(
trust_remote_code: bool, trust_remote_code: bool,
) -> Model: ) -> Model:
if dtype is None: if dtype is None:
if quantize in ["awq", "gptq"]: if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params. # These quantizers only work with float16 params.
dtype = torch.float16 dtype = torch.float16
else: else:
...@@ -402,12 +402,17 @@ def get_model( ...@@ -402,12 +402,17 @@ def get_model(
quantization_config = config_dict.get("quantization_config", None) quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None: if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None) method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq"}: if method in {"gptq", "awq", "exl2"}:
logger.info(f"Auto selecting quantization method {method}") logger.info(f"Auto selecting quantization method {method}")
quantize = method quantize = method
else: else:
logger.info(f"Unknown quantization method {method}") logger.info(f"Unknown quantization method {method}")
if quantize == "exl2" and sharded:
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)
if model_type == MAMBA: if model_type == MAMBA:
return Mamba( return Mamba(
model_id, model_id,
...@@ -881,6 +886,8 @@ def get_model( ...@@ -881,6 +886,8 @@ def get_model(
raise NotImplementedError("4bit quantization is not supported for AutoModel") raise NotImplementedError("4bit quantization is not supported for AutoModel")
elif quantize == "eetq": elif quantize == "eetq":
raise NotImplementedError("Eetq quantization is not supported for AutoModel") raise NotImplementedError("Eetq quantization is not supported for AutoModel")
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM( return CausalLM(
model_id, model_id,
......
...@@ -21,6 +21,7 @@ from transformers.activations import ACT2FN ...@@ -21,6 +21,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from loguru import logger from loguru import logger
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu": if SYSTEM != "xpu":
...@@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights): ...@@ -256,7 +257,15 @@ def _load_gqa(config, prefix: str, weights):
else: else:
g_idx = None g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
else: else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop] q = qkv_slice[q_start:q_stop]
......
...@@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -395,7 +395,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=suffix if not prefix else f"{prefix}.suffix", prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights, weights=weights,
) )
......
...@@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig): ...@@ -102,45 +102,6 @@ class MistralConfig(PretrainedConfig):
) )
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class MistralAttention(torch.nn.Module): class MistralAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
...@@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module): ...@@ -175,7 +136,13 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
......
...@@ -5,6 +5,7 @@ from torch import nn ...@@ -5,6 +5,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
...@@ -90,8 +91,15 @@ def _load_multi_mqa_gptq( ...@@ -90,8 +91,15 @@ def _load_multi_mqa_gptq(
from text_generation_server.layers.gptq import HAS_EXLLAMA from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = HAS_EXLLAMA weight = GPTQWeight(
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=HAS_EXLLAMA,
)
if bias: if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias") slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
......