Unverified Commit c5de7cd8 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Add AWQ quantization inference support (#1019) (#1054)

# Add AWQ quantization inference support

Fixes
https://github.com/huggingface/text-generation-inference/issues/781

This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.

This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).

Quick way to test this PR would be bring up TGI as follows:

```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq

text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```

Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions

[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested. 
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released

[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).

Please refer to the linked issue for benchmarks for

[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs

[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).

Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.

## Who can review?

@OlivierDehaene OR @Narsil

---------



# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation

).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: default avatarAbhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: default avatarAbhinav Kulkarni <abhinav@concentric.ai>
parent fef36cea
...@@ -135,18 +135,21 @@ class Weights: ...@@ -135,18 +135,21 @@ class Weights:
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor already alternating Q,K,V within the main tensor
""" """
if quantize == "gptq": if quantize in ["gptq", "awq"]:
try: try:
qweight = self._get_qweight(f"{prefix}.qweight") qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" f"Cannot load `{quantize}` weight, make sure the model is already quantized."
) )
qzeros = self._get_qweight(f"{prefix}.qzeros") qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales") scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype) scales = scales.to(dtype=self.dtype)
g_idx = self.get_tensor(f"{prefix}.g_idx") if quantize == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
...@@ -171,14 +174,14 @@ class Weights: ...@@ -171,14 +174,14 @@ class Weights:
return weight return weight
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq": if quantize in ["gptq", "awq"]:
try: try:
qweight = torch.cat( qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
) )
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" f"Cannot load `{quantize}` weight, make sure the model is already quantized"
) )
qzeros = torch.cat( qzeros = torch.cat(
...@@ -187,10 +190,14 @@ class Weights: ...@@ -187,10 +190,14 @@ class Weights:
scales = torch.cat( scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
) )
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]: if quantize == "gptq":
torch.testing.assert_close(w2, w[0]) w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
g_idx = w[0] for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
else:
g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
...@@ -281,6 +288,22 @@ class Weights: ...@@ -281,6 +288,22 @@ class Weights:
scales = self.get_tensor(f"{prefix}.scales") scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
bits, groupsize = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
...@@ -322,4 +345,15 @@ class Weights: ...@@ -322,4 +345,15 @@ class Weights:
self.gptq_bits = data["bits"] self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"] self.gptq_groupsize = data["group_size"]
except Exception: except Exception:
pass filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
except Exception:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment