Unverified commit 1b0af2d3 authored by Casper, committed by GitHub

Merge pull request #40 from casper-hansen/new_kernel

[NEW] GEMV kernel implementation
parents 84fb7e98 f264ebb3
......@@ -7,6 +7,7 @@
AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up models by 2x while reducing memory requirements by 3x compared to FP16. AutoAWQ implements the Activation-aware Weight Quantization (AWQ) algorithm for quantizing LLMs. AutoAWQ builds on and improves upon the [original work](https://github.com/mit-han-lab/llm-awq) from MIT.
*Latest News* 🔥
- [2023/09] 1.6x-2.5x speed boost on fused models (now including MPT and Falcon).
- [2023/09] Multi-GPU support, bug fixes, and better benchmark scripts available
- [2023/08] PyPI package released and AutoModel class available
......@@ -42,6 +43,8 @@ pip install autoawq
<summary>Build AutoAWQ from scratch</summary>
Build time can take 10 minutes. Download your model while you install AutoAWQ.
```
git clone https://github.com/casper-hansen/AutoAWQ
cd AutoAWQ
......@@ -67,9 +70,23 @@ The detailed support list:
## Usage
Below, you will find examples of how to easily quantize a model and run inference.
Under `examples`, you can find scripts for quantizing, running inference, and benchmarking AutoAWQ models.
### INT4 GEMM vs INT4 GEMV vs FP16
There are two versions of AWQ: GEMM and GEMV. Both names refer to how matrix multiplication runs under the hood. We suggest the following (a short config sketch follows this list):
- GEMV (quantized): Best for small context, batch size 1, highest number of tokens/s.
- GEMM (quantized): Best for larger context, up to batch size 8, faster than GEMV on batch size > 1, slower than GEMV on batch size = 1.
- FP16 (non-quantized): Best for large batch sizes of 8 or larger, highest throughput. We recommend [TGI](https://github.com/huggingface/text-generation-inference) or [vLLM](https://github.com/vllm-project/vllm).
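To pick a kernel, set the `version` field of the quantization config before quantizing. A minimal sketch, assuming the `from_pretrained`/`quantize` API used in the Quantization example below (the model path is a placeholder):

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"  # placeholder base model

# "version" selects the kernel: "GEMM" (default) or "GEMV"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
```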
### Examples
### Quantization
<details>
<summary>Quantization</summary>
Expect this to take 10-15 minutes on smaller 7B models, and around 1 hour for 70B models.
```python
from awq import AutoAWQForCausalLM
......@@ -91,60 +108,140 @@ model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```
### Inference
</details>
<details>
<summary>Inference</summary>
Run inference on a quantized model from Hugging Face:
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
model.generate(...)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: {prompt}
ASSISTANT:"""
tokens = tokenizer(
prompt_template.format(prompt="How are you today?"),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
```
## Benchmarks
Benchmark speeds may vary from server to server and also depend on your CPU. If you want to minimize latency, rent a GPU/CPU combination with high memory bandwidth for both and high single-core CPU speed.
| Model | GPU | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----:|:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B | 4090 | 19.97 | 8.66 | 2.31x |
| LLaMA-2-13B | 4090 | OOM | 13.54 | -- |
| Vicuna-7B | 4090 | 19.09 | 8.61 | 2.22x |
| Vicuna-13B | 4090 | OOM | 12.17 | -- |
| MPT-7B | 4090 | 17.09 | 12.58 | 1.36x |
| MPT-30B | 4090 | OOM | 23.54 | -- |
| Falcon-7B | 4090 | 29.91 | 19.84 | 1.51x |
| LLaMA-2-7B | A6000 | 27.14 | 12.44 | 2.18x |
| LLaMA-2-13B | A6000 | 47.28 | 20.28 | 2.33x |
| Vicuna-7B | A6000 | 26.06 | 12.43 | 2.10x |
| Vicuna-13B | A6000 | 44.91 | 17.30 | 2.60x |
| MPT-7B | A6000 | 22.79 | 16.87 | 1.35x |
| MPT-30B | A6000 | OOM | 31.57 | -- |
| Falcon-7B | A6000 | 39.44 | 27.34 | 1.44x |
</details>
<details>
<summary>Detailed benchmark (CPU vs. GPU)</summary>
<summary>AutoAWQForCausalLM.from_quantized</summary>
Here is the difference between a fast and slow CPU on MPT-7B:
- `quant_path`: Path to folder containing model files.
- `quant_filename`: The filename of the model weights or the `index.json` file.
- `max_new_tokens`: The max sequence length, used to allocate kv-cache for fused models.
- `fuse_layers`: Whether or not to use fused layers.
- `batch_size`: The batch size to initialize the AWQ model with (example call below).
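For reference, a minimal call sketch using these arguments; defaults follow the signature in `awq/models/auto.py`, and the `max_new_tokens` value is only an illustration:

```python
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "casperhansen/vicuna-7b-v1.5-awq",  # quant_path
    "awq_model_w4_g128.pt",             # quant_filename
    max_new_tokens=2048,                # kv-cache length for fused modules
    fuse_layers=True,
    batch_size=1,                       # must match the batch size used at inference
)
```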
RTX 4090 + Intel i9 13900K (2 different VMs):
- CUDA 12.0, Driver 525.125.06: 134 tokens/s (7.46 ms/token)
- CUDA 12.0, Driver 525.125.06: 117 tokens/s (8.52 ms/token)
</details>
RTX 4090 + AMD EPYC 7-Series (3 different VMs):
- CUDA 12.2, Driver 535.54.03: 53 tokens/s (18.6 ms/token)
- CUDA 12.2, Driver 535.54.03: 56 tokens/s (17.71 ms/token)
- CUDA 12.0, Driver 525.125.06: 55 tokens/s (18.15 ms/token)
## Benchmarks
</details>
### Vicuna 7B (LLaMA-2)
- Note: Blazing fast generation, slow context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMV
- Command: `python examples/benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq-gemv`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 231.393 | 153.632 | 4.66 GB (19.68%) |
| 1 | 64 | 64 | 233.909 | 154.475 | 4.66 GB (19.68%) |
| 1 | 128 | 128 | 233.145 | 152.133 | 4.66 GB (19.68%) |
| 1 | 256 | 256 | 228.562 | 147.692 | 4.67 GB (19.72%) |
| 1 | 512 | 512 | 228.914 | 139.179 | 4.80 GB (20.26%) |
| 1 | 1024 | 1024 | 227.393 | 125.058 | 5.56 GB (23.48%) |
| 1 | 2048 | 2048 | 225.736 | 123.228 | 8.08 GB (34.09%) |
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMM
- Command: `python examples/benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 521.444 | 126.51 | 4.55 GB (19.21%) |
| 1 | 64 | 64 | 2618.88 | 125.428 | 4.57 GB (19.31%) |
| 1 | 128 | 128 | 2808.09 | 123.865 | 4.61 GB (19.44%) |
| 1 | 256 | 256 | 2807.46 | 120.779 | 4.67 GB (19.72%) |
| 1 | 512 | 512 | 2769.9 | 115.08 | 4.80 GB (20.26%) |
| 1 | 1024 | 1024 | 2640.95 | 105.493 | 5.56 GB (23.48%) |
| 1 | 2048 | 2048 | 2341.36 | 104.188 | 8.08 GB (34.09%) |
### MPT 7B
- Note: Blazing fast generation, slow context processing
- GPU: NVIDIA GeForce RTX 3090
- Command: `python examples/benchmark.py --model_path casperhansen/mpt-7b-8k-chat-awq-gemv`
- Version: GEMV
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 187.332 | 136.765 | 3.65 GB (15.42%) |
| 1 | 64 | 64 | 241.026 | 136.476 | 3.67 GB (15.48%) |
| 1 | 128 | 128 | 239.44 | 137.599 | 3.70 GB (15.61%) |
| 1 | 256 | 256 | 233.184 | 137.02 | 3.76 GB (15.88%) |
| 1 | 512 | 512 | 233.082 | 135.633 | 3.89 GB (16.41%) |
| 1 | 1024 | 1024 | 231.504 | 122.197 | 4.40 GB (18.57%) |
| 1 | 2048 | 2048 | 228.307 | 121.468 | 5.92 GB (24.98%) |
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMM
- Command: `python examples/benchmark.py --model_path casperhansen/mpt-7b-8k-chat-awq`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 557.714 | 118.567 | 3.65 GB (15.42%) |
| 1 | 64 | 64 | 2752.9 | 120.772 | 3.67 GB (15.48%) |
| 1 | 128 | 128 | 2982.67 | 119.52 | 3.70 GB (15.61%) |
| 1 | 256 | 256 | 3009.16 | 116.911 | 3.76 GB (15.88%) |
| 1 | 512 | 512 | 2901.91 | 111.607 | 3.95 GB (16.68%) |
| 1 | 1024 | 1024 | 2718.68 | 102.623 | 4.40 GB (18.57%) |
| 1 | 2048 | 2048 | 2363.61 | 101.368 | 5.92 GB (24.98%) |
### Falcon 7B
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Command: `python examples/benchmark.py --model_path casperhansen/falcon-7b-awq --quant_file awq_model_w4_g64.pt`
- Version: GEMM
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 466.826 | 95.1413 | 4.47 GB (18.88%) |
| 1 | 64 | 64 | 1920.61 | 94.5963 | 4.48 GB (18.92%) |
| 1 | 128 | 128 | 2406.1 | 94.793 | 4.48 GB (18.92%) |
| 1 | 256 | 256 | 2521.08 | 94.1144 | 4.48 GB (18.92%) |
| 1 | 512 | 512 | 2478.28 | 93.4123 | 4.48 GB (18.92%) |
| 1 | 1024 | 1024 | 2256.22 | 94.0237 | 4.69 GB (19.78%) |
| 1 | 2048 | 2048 | 1831.71 | 94.2032 | 6.83 GB (28.83%) |
## Reference
......
import os
from transformers import AutoConfig
from awq.models import *
from awq.models.base import BaseAWQForCausalLM
......@@ -35,7 +36,9 @@ class AutoAWQForCausalLM:
@classmethod
def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
device='balanced', trust_remote_code=True, fuse_layers=True,
batch_size=1) -> BaseAWQForCausalLM:
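# batch_size is exported via the AWQ_BATCH_SIZE env var so fused attention modules can size their kv-cache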
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
......
......@@ -2,15 +2,17 @@ import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
......@@ -41,6 +43,7 @@ class BaseAWQForCausalLM(nn.Module):
auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"):
self.quant_config = quant_config
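# older quant configs have no 'version' key; fall back to the original GEMM kernels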
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
......@@ -51,7 +54,7 @@ class BaseAWQForCausalLM(nn.Module):
self.is_quantized = True
@staticmethod
def fuse_layers(model):
def fuse_layers(model, quant_config):
pass
def _awq_quant(self):
......@@ -70,18 +73,23 @@ class BaseAWQForCausalLM(nn.Module):
module.weight.data, scales, zeros = pseudo_quantize_tensor(
module.weight.data,
get_scale_zp=True,
**self.quant_config
w_bit=self.quant_config["w_bit"],
q_group_size=self.quant_config["q_group_size"]
)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
False,
scales,
if self.quant_config["version"] == 'GEMM':
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.quant_config["version"] == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
False,
scales,
zeros
)
......@@ -253,7 +261,7 @@ class BaseAWQForCausalLM(nn.Module):
@classmethod
def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
safetensors=False, is_quantized=True, fuse_layers=False):
safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
# [STEP 1] Download model if path is not a directory
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*"]
......@@ -273,9 +281,12 @@ class BaseAWQForCausalLM(nn.Module):
if os.path.exists(quant_config_path):
with open(quant_config_path, 'r') as file:
quant_config = json.loads(file.read())
if "version" not in quant_config.keys():
quant_config["version"] = version
else:
# Default config that works for most models
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
......@@ -293,7 +304,7 @@ class BaseAWQForCausalLM(nn.Module):
# Only need to replace layers if a model is AWQ quantized
if is_quantized:
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config)
self._load_quantized_modules(self, model, quant_config, quant_config["version"])
model.tie_weights()
......@@ -313,7 +324,7 @@ class BaseAWQForCausalLM(nn.Module):
)
if fuse_layers:
self.fuse_layers(model)
self.fuse_layers(model, quant_config)
else:
# If not quantized, must load with AutoModelForCausalLM
......@@ -333,7 +344,7 @@ class BaseAWQForCausalLM(nn.Module):
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_quantized_modules(self, model, quant_config):
def _load_quantized_modules(self, model, quant_config, version):
# Real quantization of weights
assert quant_config["zero_point"], "We only support zero_point quantization now."
......@@ -351,8 +362,17 @@ class BaseAWQForCausalLM(nn.Module):
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
q_linear = WQLinear.from_linear(
module, quant_config['w_bit'], quant_config['q_group_size'], True)
if version == 'GEMM':
q_linear_module = WQLinear_GEMM
elif version == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
quant_config['w_bit'],
quant_config['q_group_size'],
True
)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
......
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer"
@staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config: dict):
fuser = FalconFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: FalconForCausalLM):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module: FalconDecoderLayer):
def get_act_for_scaling(module: OldFalconDecoderLayer):
return dict(
is_scalable=True,
scale_name="mlp.act",
......@@ -22,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
@staticmethod
def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
layers = []
# Falcon 7B (older architecture)
......@@ -56,4 +61,48 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
kwargs=module_kwargs,
))
return layers
\ No newline at end of file
return layers
from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer
class FalconFuser:
def __init__(self, model: FalconForCausalLM):
self.model = model
def fuse_transformer(self):
blocks = []
module: OldFalconDecoderLayer
for module in self.model.transformer.h:
if module.config.num_attention_heads == 71:
input_layernorm = module.input_layernorm
ln_attn = None
ln_mlp = None
new_decoder_arch = False
else:
input_layernorm = None
ln_attn = module.ln_attn
ln_mlp = module.ln_mlp
new_decoder_arch = True
blocks.append(FalconDecoderLayer(
hidden_size=module.config.hidden_size,
n_heads=module.config.num_attention_heads,
qkv_layer=module.self_attention.query_key_value,
o_proj=module.self_attention.dense,
mlp=module.mlp,
dev=next(iter(module.state_dict().values())).device,
max_seq_len=self.model.config.max_new_tokens,
input_layernorm=input_layernorm,
ln_attn=ln_attn,
ln_mlp=ln_mlp,
new_decoder_arch=new_decoder_arch
))
self.model.transformer = FalconModel(
self.model.config.vocab_size,
blocks,
self.model.transformer.word_embeddings,
self.model.transformer.ln_f,
)
\ No newline at end of file
......@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: LlamaForCausalLM):
fuser = LlamaFuser(model)
def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
fuser = LlamaFuser(model, quant_config)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
......@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
return layers
import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantLlamaMLP
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser:
def __init__(self, model):
def __init__(self, model, quant_config):
self.model = model
self.quant_config = quant_config
self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules()
......@@ -95,12 +96,11 @@ class LlamaFuser:
def fuse_attention(self):
for name, module in self.attention_modules:
qkv_layer: WQLinear = self._fuse_qkv(module)
attn = QuantLlamaAttention(
qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
attn = QuantAttentionFused(
module.hidden_size,
module.num_heads,
module.num_key_value_heads,
qkv_layer,
qkv_layer,
module.o_proj,
next(iter(qkv_layer.state_dict().values())).device,
self.model.config.max_new_tokens
......@@ -108,24 +108,33 @@ class LlamaFuser:
set_module_name(self.model, name, attn)
def _fuse_qkv(self, module: LlamaAttention):
# get qkv and bias
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module
qkv_layer = WQLinear(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
if isinstance(q_proj, WQLinear_GEMV):
q_linear = WQLinear_GEMV
else:
q_linear = WQLinear_GEMM
qkv_layer = q_linear(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
next(iter(module.state_dict().values())).device
)
# replace buffers with real weights
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
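# WQLinear_GEMV buffers are out_features-major, so fused q/k/v tensors concatenate along dim 0; the GEMM layout packs along dim 1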
if isinstance(qkv_layer, WQLinear_GEMV):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
qkv_layer.split_k_iters = q_proj.split_k_iters
else:
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
qkv_layer.bias = bias
return qkv_layer
......
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_new_tokens_key = "max_seq_len"
@staticmethod
def fuse_layers(model: MptForCausalLM):
def fuse_layers(model: MptForCausalLM, quant_config: dict):
fuser = MptFuser(model)
fuser.fuse_mlp()
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: MptForCausalLM):
return model.transformer.blocks
@staticmethod
def get_act_for_scaling(module: MptBlock):
def get_act_for_scaling(module: OldMptBlock):
return dict(
is_scalable=True,
scale_name="ffn.act",
......@@ -29,7 +29,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
@staticmethod
def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
layers = []
# attention input
......@@ -67,24 +67,38 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantMPTMLP
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel
class MptFuser:
def __init__(self, model):
def __init__(self, model: MptForCausalLM):
self.model = model
self.mlp_modules: List[Tuple[str, MptMLP]] = [
self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
(name, module) for name, module in self.model.named_modules()
if isinstance(module, MptMLP)
if 'mptblock' in module.__class__.__name__.lower()
]
def fuse_attention(self):
pass
def fuse_layernorm(self):
pass
def fuse_transformer(self):
blocks = []
module: OldMptBlock
for module in self.model.transformer.blocks:
blocks.append(MPTBlock(
self.model.config.d_model,
self.model.config.n_heads,
module.attn.Wqkv,
module.attn.out_proj,
module.ffn,
module.norm_1,
module.norm_2,
next(iter(module.state_dict().values())).device,
self.model.config.max_new_tokens
))
def fuse_mlp(self):
for name, module in self.mlp_modules:
mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
set_module_name(self.model, name, mlp)
\ No newline at end of file
self.model.transformer = MPTModel(
self.model.config.vocab_size,
blocks,
self.model.transformer.wte,
self.model.transformer.norm_f,
)
\ No newline at end of file
from .fused_norm import *
from .fused_attn import *
from .fused_mlp import *
import torch.nn as nn
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
import os
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
def build_alibi_bias(
n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
if full:
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, seq_len, 1
)
alibi_bias = alibi_bias.abs().mul(-1)
slopes = gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
# print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding_neox(
positions,
query,
key,
self.dim,
self.cos_sin_cache
)
return query, key
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
use_alibi=False, attention_shapes=None):
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = self.hidden_size // num_heads
self.qkv_proj = qkv_layer
self.o_proj = o_proj
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.attention_shapes = attention_shapes if attention_shapes is not None else {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_local_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
"xq_view": (self.n_local_heads, self.head_dim),
"xk_view": (self.n_local_heads, self.head_dim),
"xv_view": (self.n_local_heads, self.head_dim),
"single_xq_view": (self.n_local_heads, self.head_dim),
"single_xk_view": (self.n_local_heads, self.head_dim),
"single_xv_view": (self.n_local_heads, self.head_dim)
}
self.cache_v = (
torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
)
self.cache_k = (
torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
)
if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // num_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True
def forward(
self,
hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
raise RuntimeError(
f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
)
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
xq = self.attention_shapes["xq_slice"](xqkv)
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)
if seqlen > 1:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
if not self.use_alibi:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
values_store = xv.transpose(2, 1)
keys_store = (
xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
.permute(0, 2, 3, 1, 4)
.contiguous()
)
self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
keys = xk
values = xv
past_key_value = (xk, xv) if use_cache else None
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores += self.alibi_bias[..., :seqlen]
if attention_mask is not None:
scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
else:
# xq = xq[:, 0, :, :]
# xk = xk[:, 0, :, :]
# xv = xv[:, 0, :, :]
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
past_key_value = (xk, xv) if use_cache else None
attention_weight = awq_inference_engine.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache_k, # key cache
self.cache_v, # value cache
None, # length per sample
self.alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
10000, # rotary embedding base
self.is_neox, # is neox
)
attention_weight = attention_weight.reshape(bsz, 1, -1)
attn_output = self.o_proj(attention_weight)
if use_cache:
self.start_pos += seqlen
else:
self.start_pos = 0
return attn_output, attention_weight, past_key_value
\ No newline at end of file
import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MPTBlock(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
super().__init__()
self.n_heads = n_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1
self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
self.norm_2 = norm_2
self.ffn = mpt_mlp.to(dev)
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
)
h = hidden_states + attn_output
out = h + self.ffn.forward(self.norm_2(h))
return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
super().__init__()
self.n_heads = n_heads
self.hidden_size = hidden_size
self.new_decoder_arch = new_decoder_arch
attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False,
attention_shapes=attention_shapes
).to(dev)
if new_decoder_arch:
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
else:
self.input_layernorm = input_layernorm # before attention
self.mlp = mlp
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
if new_decoder_arch:
kv_heads = 8
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :,0],
"xk_slice": lambda xqkv: xqkv[:, :, :,1],
"xv_slice": lambda xqkv: xqkv[:, :, :,2],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (1, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, 8, head_dim),
"single_xv_view": (1, 8, head_dim)
}
else:
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, 1, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads+2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (n_heads, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim)
}
return self.attention_shapes
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
if self.new_decoder_arch:
layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
layernorm_out = self.input_layernorm(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=layernorm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
)
h_attn = hidden_states + attn_output
if self.new_decoder_arch:
h_mlp = self.mlp.forward(mlp_layernorm_out)
else:
h_mlp = self.mlp.forward(layernorm_out)
out = h_attn + h_mlp
return out, None, past_key_value
import torch
import torch.nn as nn
import awq_inference_engine
import torch.nn.functional as F
class QuantMPTMLP(nn.Module):
def __init__(
self,
up_proj,
act,
down_proj
):
super().__init__()
self.register_buffer('up_proj_qweight', up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros)
self.up_proj = up_proj
self.act = act
self.down_proj = down_proj
def forward(self, x: torch.Tensor):
x = x.reshape(-1, x.shape[-1])
x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
return self.down_proj(self.act(x))
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
class QuantLlamaMLP(nn.Module):
......@@ -31,7 +9,7 @@ class QuantLlamaMLP(nn.Module):
self,
gate_proj,
down_proj,
up_proj,
up_proj
):
super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight)
......@@ -47,19 +25,32 @@ class QuantLlamaMLP(nn.Module):
self.w_bit = gate_proj.w_bit
self.down_proj = down_proj
if isinstance(down_proj, WQLinear_GEMV):
self.linear = awq_inference_engine.gemv_forward_cuda
self.group_size = down_proj.group_size
else:
self.linear = awq_inference_engine.gemm_forward_cuda
self.group_size = 8
def forward(self, x):
return self.down_proj(self.our_llama_mlp(x))
def our_llama_mlp(self, x):
out_shape = x.shape[:-1] + (self.intermediate_size, )
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = awq_inference_engine.gemm_forward_cuda(
x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8
gate_output = self.linear(
x,
self.gate_proj_qweight,
self.gate_proj_scales,
self.gate_proj_qzeros,
self.group_size,
)
gate_output = F.silu(gate_output)
up_output = awq_inference_engine.gemm_forward_cuda(
x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
up_output = self.linear(
x,
self.up_proj_qweight,
self.up_proj_scales,
self.up_proj_qzeros,
self.group_size,
)
c = gate_output * up_output
c = c.reshape(out_shape)
return c
x = F.silu(gate_output) * up_output
x = x.reshape(out_shape)
x = self.down_proj(x)
return x
\ No newline at end of file
import torch
import torch.nn as nn
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
class MPTModel(nn.Module):
def __init__(self, vocab_size, blocks, wte, norm_f):
super().__init__()
self.vocab_size = vocab_size
self.wte = wte
self.blocks: list[MPTBlock] = nn.ModuleList(blocks)
self.norm_f = norm_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
_bsz, seqlen = input_ids.shape
h = self.wte(input_ids)
mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
)
mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
for layer in self.blocks:
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
class FalconModel(nn.Module):
def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
super().__init__()
self.vocab_size = vocab_size
self.word_embeddings = word_embeddings
self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
self.ln_f = ln_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
# NOTE: falcon input ids contain full context
# after context is processed, slice to latest token
if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1:
input_ids = input_ids[:, self.blocks[0].attn.start_pos:]
_bsz, seqlen = input_ids.shape
h = self.word_embeddings(input_ids)
mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
)
mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
for layer in self.blocks:
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.ln_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# [max_position, rot_dim]
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding(
positions,
query,
key,
self.dim,
self.cos_sin_cache,
True # is_neox
)
return query, key
class QuantLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
num_kv_heads,
qkv_proj,
o_proj,
dev,
max_new_tokens,
use_hf_rotary=False
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.use_hf_rotary = use_hf_rotary
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads}).")
self.qkv_proj = qkv_proj
self.o_proj = o_proj
if use_hf_rotary:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
else:
self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
if self.use_hf_rotary:
# get qkv
qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
query, key, value = torch.split(qkv_states, 1, dim=2)
del qkv_states
# reshape for hf rotary
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
else:
# get qkv
query, key, value = qkv_states.chunk(chunks=3, dim=-1)
del qkv_states
# [num_tokens, num_heads * head_size]
query_batch_size, query_len, _ = query.shape
query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)
# [num_tokens, num_kv_heads * head_size]
key_batch_size, key_len, _ = key.shape
key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)
# [num_tokens]
positions = position_ids.view(-1).to(query.device)
query, key = self.rotary_emb(query, key, positions)
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
is_causal = past_key_value is None
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
value = value.to(key.device)
if past_key_value is not None:
# reuse k, v, self_attention
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
if use_cache:
# Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
key = key.contiguous()
value = value.contiguous()
query = query.contiguous()
past_key_value = (key, value) if use_cache else None
# with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
del query, key, value
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
......@@ -4,17 +4,24 @@ import torch.nn as nn
import awq_inference_engine # with CUDA kernels
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
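# number of packed int32 columns needed for GEMV zero-points, padded so each row stays aligned for the kernel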
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
class WQLinear(nn.Module):
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
......@@ -25,6 +32,7 @@ class WQLinear(nn.Module):
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
......@@ -74,7 +82,7 @@ class WQLinear(nn.Module):
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
for col in range(zeros.shape[1] // pack_num):
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
......@@ -97,3 +105,100 @@ class WQLinear(nn.Module):
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
class WQLinear_GEMV(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = (32 // self.w_bit)
self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
dtype=torch.float16,
device=scales.device
)
qscales[:, :scales.shape[1]] = scales
awq_linear.scales = qscales
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros(
(zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
dtype=torch.int32,
device=zeros.device,
)
for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
if col * pack_num + order_map[i] >= zeros.shape[1]:
continue
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
\ No newline at end of file
......@@ -43,7 +43,7 @@ def auto_clip_layer(w,
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, **quant_config)
q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"])
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
......
......@@ -7,7 +7,7 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from .qmodule import ScaledActivation
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
......@@ -98,7 +98,7 @@ def auto_scale_block(awq_model,
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if quant_config['w_bit'] is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(p, **quant_config).detach()
def w_quantize_func(p): return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
else:
def w_quantize_func(p): return p
......
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
    fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
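// Note on the block below: the arithmetic operators and make_bfloat162 are only
// supplied for __CUDA_ARCH__ < 800, since cuda_bf16.h ships native bf16 device
// operators on sm_80+ and redefining them there would be redundant (or conflict).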
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t; t.x = x; t.y = y; return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
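// Illustrative usage sketch (not part of the upstream FasterTransformer source):
// every wrapper above follows the same pre-Ampere (sm < 80) fallback: widen bf16 to
// float, compute in float, and round back with __floats2bfloat162_rn /
// __float2bfloat16, while on sm_80+ it lowers directly to the native
// __hadd2 / __hmul2 / __hfma2 intrinsics. A kernel written against the wrappers
// therefore stays arch-portable, e.g. a hypothetical accumulate-and-reduce step:
//
//   __nv_bfloat162 acc = bf16hfma2(w, x, acc);                                 // acc += w * x, per lane
//   __nv_bfloat16 sum  = bf16hadd(__low2bfloat16(acc), __high2bfloat16(acc));  // fold the two lanes
//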
#endif // ENABLE_BF16
} // namespace fastertransformer
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
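// MMHA_LAUNCH_KERNEL instantiates one masked multi-head attention kernel
// specialization, computes its dynamic shared-memory requirement, opts into more
// than 48 KB of shared memory via cudaFuncSetAttribute when needed, and launches a
// (num_heads x batch_size) grid with THDS_PER_BLOCK threads per block.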
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The same launcher handles both self- and cross-attention; DO_CROSS_ATTENTION is derived from the params type.
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
// printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#undef MMHA_LAUNCH_KERNEL
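// multihead_attention_ dispatches on the runtime head size: Dh is the actual
// hidden_size_per_head, while Dh_MAX rounds it up to the next power of two
// (48 -> 64, 80/96/112 -> 128, 160/192/224 -> 256) for the compile-time sizing in
// the kernel template (e.g. THREADS_PER_VALUE above). Unsupported head sizes fall
// through to assert(false).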
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head) {
case 32:
mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 48:
mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 64:
mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 80:
mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 96:
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 112:
mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 160:
mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 192:
mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 224:
mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 256:
mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////