Commit 4594e6fa authored and committed by Daniël de Kok

Add support for Marlin-quantized models

This change adds support for Marlin-quantized models. Marlin is an
FP16xINT4 matmul kernel that provides good speedups when decoding batches
of 16-32 tokens. It supports 4-bit quantized models with symmetric
quantization and a group size of -1 or 128.

Tested with:

- Llama 2
- Llama 3
- Phi 3
parent cf0d459a
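For illustration, a minimal sketch (not part of this commit) of the compatibility constraints described in the commit message, assuming a GPTQ-style quantization_config with bits, sym, and group_size fields (the field names are an assumption):

# Sketch: check whether a quantization config satisfies the constraints
# Marlin supports (4-bit, symmetric quantization, group size -1 or 128).
# The field names (bits, sym, group_size) follow the common GPTQ-style
# quantization_config convention and are assumed here.
def is_marlin_compatible(quantization_config: dict) -> bool:
    bits = quantization_config.get("bits")
    sym = quantization_config.get("sym", False)
    group_size = quantization_config.get("group_size", -1)
    return bits == 4 and sym and group_size in (-1, 128)

# Example: a symmetric 4-bit config with group size 128 is compatible.
assert is_marlin_compatible({"bits": 4, "sym": True, "group_size": 128})
assert not is_marlin_compatible({"bits": 8, "sym": True, "group_size": 128})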
......@@ -137,6 +137,13 @@ COPY server/Makefile-eetq Makefile
# Build specific version of eetq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels
FROM kernel-builder as marlin-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-marlin Makefile
# Build specific version of marlin
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-marlin
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
......@@ -205,6 +212,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/marlin/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
......
......@@ -64,6 +64,7 @@ Options:
- eetq: 8 bit quantization, doesn't require a specific model. Should be a drop-in replacement for bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
- gptq: 4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- marlin: 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0644531,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 5229,
"logprob": -1.2607422,
"special": false,
"text": " failed"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 6527,
"logprob": -0.11450195,
"special": false,
"text": " Could"
},
{
"id": 451,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 4511,
"logprob": -0.2286377,
"special": false,
"text": " connect"
},
{
"id": 304,
"logprob": 0.0,
"special": false,
"text": " to"
},
{
"id": 1923,
"logprob": -1.2568359,
"special": false,
"text": " server"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.15905762,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -0.21618652,
"special": false,
"text": "I"
}
],
"top_tokens": null
},
"generated_text": "Test request failed: Could not connect to server\n\nI"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -12.390625,
"text": "Test"
},
{
"id": 2009,
"logprob": -11.0625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0507812,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -2.3007812,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.0449219,
"special": false,
"text": "I"
},
{
"id": 505,
"logprob": -1.3242188,
"special": false,
"text": " have"
},
{
"id": 263,
"logprob": -0.2076416,
"special": false,
"text": " a"
},
{
"id": 1243,
"logprob": -2.0273438,
"special": false,
"text": " test"
},
{
"id": 2009,
"logprob": -0.6845703,
"special": false,
"text": " request"
},
{
"id": 515,
"logprob": -1.1748047,
"special": false,
"text": " from"
},
{
"id": 263,
"logprob": -1.0595703,
"special": false,
"text": " a"
},
{
"id": 1404,
"logprob": -1.5224609,
"special": false,
"text": " user"
}
],
"top_tokens": null
},
"generated_text": "\n\nI have a test request from a user"
}
]
import pytest
@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
with launcher(
"neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
await flash_llama_marlin_handle.health(300)
return flash_llama_marlin_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
response = await flash_llama_marlin.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
flash_llama_marlin, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_marlin, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
......@@ -64,6 +64,8 @@ enum Quantization {
/// triton kernel (wider support) when it's not.
/// AWQ has faster kernels.
Gptq,
/// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.
Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16.
#[deprecated(
......@@ -105,6 +107,9 @@ impl std::fmt::Display for Quantization {
Quantization::Gptq => {
write!(f, "gptq")
}
Quantization::Marlin => {
write!(f, "marlin")
}
Quantization::Awq => {
write!(f, "awq")
}
......
......@@ -3,6 +3,7 @@ include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-marlin
include Makefile-selective-scan
unit-tests:
......
marlin_commit := 2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c
marlin:
# Clone marlin
pip install packaging
git clone https://github.com/IST-DASLab/marlin.git marlin
build-marlin: marlin
cd marlin && git fetch && git checkout $(marlin_commit)
cd marlin && python setup.py build
install-marlin: build-marlin
cd marlin && python setup.py install
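A quick post-install sanity check, as a sketch rather than part of this commit: it verifies that the marlin extension built by install-marlin is importable and that the GPU meets the compute capability requirement enforced in the Python layer below.

# Sanity-check sketch (not part of this commit): confirm the marlin kernels
# installed by `make install-marlin` are importable and usable on this GPU,
# mirroring the checks in text_generation_server/layers/marlin.py.
import torch

try:
    import marlin
except ImportError:
    marlin = None

if marlin is None:
    print("marlin kernels not installed (run `cd server && make install-marlin`)")
elif not hasattr(marlin, "mul"):
    print("marlin is installed but does not expose the expected `mul` kernel")
elif not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8:
    print("marlin requires a GPU with CUDA compute capability 8.0 or later")
else:
    print("marlin kernels are available")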
......@@ -21,6 +21,7 @@ class Quantization(str, Enum):
eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8"
marlin = "marlin"
class Dtype(str, Enum):
......
......@@ -222,6 +222,14 @@ def get_linear(weight, bias, quantize):
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinLinear, MarlinWeight
if not isinstance(weight, MarlinWeight):
raise NotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
linear = MarlinLinear(B=weight.B, s=weight.s, bias=bias)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
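For context, a sketch of how a weight loader could produce the MarlinWeight that get_linear expects for --quantize marlin. It assumes the checkpoint stores the packed weights and scales under {prefix}.B and {prefix}.s and that the weights object exposes a get_tensor method; neither is shown in this diff, so treat both as assumptions.

# Hypothetical loader sketch (not part of this diff): build a MarlinWeight
# from a Marlin-exported checkpoint. The tensor names `{prefix}.B` and
# `{prefix}.s` and the `weights.get_tensor` accessor are assumptions here.
from text_generation_server.layers.marlin import MarlinWeight


def load_marlin_weight(weights, prefix: str) -> MarlinWeight:
    B = weights.get_tensor(f"{prefix}.B")  # int32, tile-packed int4 weights
    s = weights.get_tensor(f"{prefix}.s")  # float16 per-group scales
    return MarlinWeight(B=B, s=s)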
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
try:
import marlin
except ImportError:
marlin = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
MARLIN_TILE_SIZE = 16
@dataclass
class MarlinWeight:
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
class MarlinLinear(nn.Module):
def __init__(
self, *, B: torch.Tensor, s: torch.Tensor, bias: Optional[torch.Tensor]
):
super().__init__()
if not has_sm_8_0:
raise NotImplementedError(
"Using quantized marlin models requires CUDA capability 8.0 or later"
)
if marlin is None:
raise NotImplementedError(
"You do not seem to have marlin installed, either install it (cd server && make install-marlin)"
)
assert B.dtype == torch.int32
assert s.dtype == torch.float16
in_features = B.shape[0] * MARLIN_TILE_SIZE
out_features = s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0]
assert group_size in {
-1,
128,
}, f"Group size must be -1 or 128, was {group_size}"
self.register_buffer("B", B)
self.register_buffer("s", s)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 128 * 16, dtype=torch.int, device=B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin is not None
C = torch.empty(
A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device
)
marlin.mul(
A.view((-1, A.shape[-1])),
self.B,
C.view((-1, C.shape[-1])),
self.s,
self.workspace,
)
if self.bias is not None:
C += self.bias
return C
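To make the shape constraints above concrete, a small CPU-only sketch (no marlin kernels required) that derives the logical dimensions and group size from the packed tensors, mirroring the assertions in MarlinLinear; the example B shape follows the upstream Marlin packing and is only illustrative.

# Shape sketch (not part of this commit): recover in_features, out_features
# and group size from Marlin-packed tensors, mirroring MarlinLinear's checks.
import torch

MARLIN_TILE_SIZE = 16


def describe_marlin_weight(B: torch.Tensor, s: torch.Tensor):
    in_features = B.shape[0] * MARLIN_TILE_SIZE
    out_features = s.shape[1]
    group_size = -1 if s.shape[0] == 1 else in_features // s.shape[0]
    assert in_features % 128 == 0 and out_features % 256 == 0
    assert group_size in (-1, 128)
    return in_features, out_features, group_size


# Example: a 4096x4096 layer quantized with group size 128. The packed B
# shape (in_features // 16, out_features * 16 // 8) is illustrative only;
# MarlinLinear itself relies solely on B.shape[0].
B = torch.zeros((4096 // 16, 4096 * 16 // 8), dtype=torch.int32)
s = torch.ones((4096 // 128, 4096), dtype=torch.float16)
print(describe_marlin_weight(B, s))  # (4096, 4096, 128)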
......@@ -64,7 +64,7 @@ class TensorParallelHead(SuperLayer):
should_gather = False
# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq"]:
if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
quantize = None
# See above, exl2 LM head can be quantized or not.
elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
......
......@@ -260,7 +260,7 @@ def get_model(
) -> Model:
global FLASH_ATTENTION
if dtype is None:
if quantize in ["awq", "exl2", "gptq"]:
if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params.
dtype = torch.float16
else:
......
......@@ -271,6 +271,11 @@ def _load_gqa(config, prefix: str, weights):
groupsize=groupsize,
use_exllama=use_exllama,
)
elif config.quantize == "marlin":
# NOTE: at the time marlin support was added, the only model that
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
# but it requires manual concatenation of weight files.
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
......
......@@ -145,7 +145,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
......
......@@ -46,6 +46,10 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix,
weights,
)
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
......
......@@ -139,7 +139,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
......
......@@ -89,7 +89,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
......
......@@ -46,7 +46,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
......