Unverified Commit aefde28b authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)

Let's start discussing implementation.

- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).

Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.

My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation

).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>
parent bd3a9d8e
......@@ -159,6 +159,11 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
g++ \
&& rm -rf /var/lib/apt/lists/*
# AWS Sagemaker compatbile image
FROM base as sagemaker
......
This diff is collapsed.
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
deprecated==1.2.13 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
......@@ -26,17 +26,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.1 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.2 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
......@@ -151,5 +151,37 @@ def download_weights(
utils.convert_files(local_pt_files, local_st_files)
@app.command()
def quantize(
model_id: str,
output_dir: str,
revision: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
upload_to_model_id: Optional[str] = None,
percdamp: float = 0.01,
act_order: bool = False,
):
download_weights(
model_id=model_id,
revision=revision,
logger_level=logger_level,
json_output=json_output,
)
from text_generation_server.utils.gptq.quantize import quantize
quantize(
model_id=model_id,
bits=4,
groupsize=128,
output_dir=output_dir,
trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id,
percdamp=percdamp,
act_order=act_order,
)
if __name__ == "__main__":
app()
......@@ -246,6 +246,10 @@ def get_model(
if sharded:
raise ValueError("sharded is not supported for AutoModel")
if quantize == "gptq":
raise ValueError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
......
......@@ -42,7 +42,8 @@ from text_generation_server.utils.layers import (
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
......@@ -57,19 +58,21 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = (
weight.view(
num_heads,
3,
head_size,
hidden_size,
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
if isinstance(weight, torch.Tensor):
# Only on non quantized versions
weight = (
weight.view(
num_heads,
3,
head_size,
hidden_size,
)
.permute(1, 0, 2, 3)
.reshape(-1, hidden_size)
)
.permute(1, 0, 2, 3)
.reshape(-1, hidden_size)
)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
......
......@@ -21,7 +21,8 @@ from text_generation_server.utils.layers import (
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
......
......@@ -21,6 +21,81 @@ from text_generation_server.utils.layers import (
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if config.quantize == "gptq":
return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
else:
return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size
)
def _load_multi_mqa_gptq(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
world_size = weights.process_group.size()
rank = weights.process_group.rank()
slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
shape = slice_.get_shape()
block_size = (shape[1] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[1] - 2 * head_size) % world_size == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
shape = slice_.get_shape()
block_size = (shape[1] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[1] - 2 * head_size) % world_size == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
scales = torch.cat([q_tensor, kv_tensor], dim=1)
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
shape = slice_.get_shape()
block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert 2 * head_size % (32 // 4) == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
block_size = (shape[0] - 2 * head_size) // world_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
raise NotImplementedError("Gptq loading with santacoder is not implemented")
def _load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
......@@ -92,7 +167,9 @@ def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=0
)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
......@@ -105,7 +182,7 @@ def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
......
......@@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
from typing import Optional
from text_generation_server.models import FlashCausalLM
......@@ -34,13 +34,22 @@ class FlashLlama(FlashCausalLM):
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
except Exception:
tokenizer = LlamaTokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
......
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
"""
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = (
prune_configs_by["perf_model"],
prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40
)
except triton.compiler.OutOfResources:
return (float("inf"), float("inf"), float("inf"))
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
:top_k
]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple time.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predicate running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
"""
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs["GROUP_SIZE_M"]
if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue
used.add(
(
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
)
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(
scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
except:
print("triton not installed.")
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
grid = lambda META: (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // 4
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
This diff is collapsed.
......@@ -15,6 +15,8 @@ except ImportError:
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
# Monkey patching
@classmethod
......@@ -129,7 +131,22 @@ def get_linear(weight, bias, quantize):
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
raise NotImplementedError("Soon")
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
......@@ -152,8 +169,14 @@ class TensorParallelHead(SuperLayer):
@staticmethod
def load(config, prefix: str, weights):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
# GPTQ doesn't quantize heads (nor embeddings)
if config.quantize == "gptq":
quantize = None
else:
quantize = config.quantize
return TensorParallelHead(
get_linear(weight, bias=None, quantize=config.quantize),
get_linear(weight, bias=None, quantize=quantize),
process_group=weights.process_group,
)
......@@ -196,24 +219,21 @@ class TensorParallelHead(SuperLayer):
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
return cls.load_multi(config, [prefix], weights, bias, dim=0)
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=0)
bias = torch.cat(b, dim=dim)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
linear = get_linear(weight, bias, config.quantize)
return cls(linear)
class TensorParallelRowLinear(SuperLayer):
......@@ -223,7 +243,8 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
......
from pathlib import Path
from typing import List, Dict, Optional
from safetensors import safe_open
import torch
class Weights:
......@@ -54,7 +55,10 @@ class Weights:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
tensor = tensor.to(dtype=self.dtype)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype not in [torch.int32, torch.int64]:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
......@@ -80,6 +84,49 @@ class Weights:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment