Unverified commit 1b0af2d3 authored by Casper, committed by GitHub

Merge pull request #40 from casper-hansen/new_kernel

[NEW] GEMV kernel implementation
parents 84fb7e98 f264ebb3
...@@ -7,6 +7,7 @@
AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up models by 2x while reducing memory requirements by 3x compared to FP16. AutoAWQ implements the Activation-aware Weight Quantization (AWQ) algorithm for quantizing LLMs. AutoAWQ is based on and improves upon the [original work](https://github.com/mit-han-lab/llm-awq) from MIT.
*Latest News* 🔥
- [2023/09] 1.6x-2.5x speed boost on fused models (now including MPT and Falcon).
- [2023/09] Multi-GPU support, bug fixes, and better benchmark scripts available
- [2023/08] PyPi package released and AutoModel class available
...@@ -42,6 +43,8 @@ pip install autoawq
<summary>Build AutoAWQ from scratch</summary>
Build time can take 10 minutes. Download your model while you install AutoAWQ.
```
git clone https://github.com/casper-hansen/AutoAWQ
cd AutoAWQ
...@@ -67,9 +70,23 @@ The detailed support list:
## Usage
Below, you will find examples of how to quantize a model and run inference. The `examples` directory contains scripts for quantizing, running inference, and benchmarking AutoAWQ models.
### INT4 GEMM vs INT4 GEMV vs FP16
There are two versions of AWQ: GEMM and GEMV. Both names refer to how the matrix multiplication runs under the hood. We suggest the following (a short sketch of selecting a version follows the list):
- GEMV (quantized): Best for small context, batch size 1, highest number of tokens/s.
- GEMM (quantized): Best for larger context, up to batch size 8, faster than GEMV on batch size > 1, slower than GEMV on batch size = 1.
- FP16 (non-quantized): Best for large batch sizes of 8 or larger, highest throughput. We recommend [TGI](https://github.com/huggingface/text-generation-inference) or [vLLM](https://github.com/vllm-project/vllm).
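The kernel version is chosen at quantization time through the `version` key of `quant_config`. The snippet below is a minimal sketch, assuming the `AutoAWQForCausalLM.from_pretrained`/`quantize` flow used in the Quantization example further down; the model path is a placeholder.

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"     # placeholder source model
quant_path = "vicuna-7b-v1.5-awq-gemv"  # output folder

# "version" selects the kernel: GEMV for batch size 1 / lowest latency,
# GEMM for larger context and batch sizes up to ~8.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```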
### Examples
<details>
<summary>Quantization</summary>
Expect this to take 10-15 minutes on smaller 7B models, and around 1 hour for 70B models.
```python
from awq import AutoAWQForCausalLM
...@@ -91,60 +108,140 @@ model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```
</details>
<details>
<summary>Inference</summary>
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
USER: {prompt}
ASSISTANT:"""
tokens = tokenizer(
prompt_template.format(prompt="How are you today?"),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
```
</details>

## Benchmarks

Benchmark speeds may vary from server to server and also depend on your CPU. If you want to minimize latency, rent a GPU/CPU combination with high memory bandwidth for both and high single-core CPU speed.
| Model | GPU | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----:|:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B | 4090 | 19.97 | 8.66 | 2.31x |
| LLaMA-2-13B | 4090 | OOM | 13.54 | -- |
| Vicuna-7B | 4090 | 19.09 | 8.61 | 2.22x |
| Vicuna-13B | 4090 | OOM | 12.17 | -- |
| MPT-7B | 4090 | 17.09 | 12.58 | 1.36x |
| MPT-30B | 4090 | OOM | 23.54 | -- |
| Falcon-7B | 4090 | 29.91 | 19.84 | 1.51x |
| LLaMA-2-7B | A6000 | 27.14 | 12.44 | 2.18x |
| LLaMA-2-13B | A6000 | 47.28 | 20.28 | 2.33x |
| Vicuna-7B | A6000 | 26.06 | 12.43 | 2.10x |
| Vicuna-13B | A6000 | 44.91 | 17.30 | 2.60x |
| MPT-7B | A6000 | 22.79 | 16.87 | 1.35x |
| MPT-30B | A6000 | OOM | 31.57 | -- |
| Falcon-7B | A6000 | 39.44 | 27.34 | 1.44x |
<details>
<summary>AutoAWQForCausalLM.from_quantized</summary>

- `quant_path`: Path to folder containing model files.
- `quant_filename`: The filename to model weights or `index.json` file.
- `max_new_tokens`: The max sequence length, used to allocate kv-cache for fused models.
- `fuse_layers`: Whether or not to use fused layers.
- `batch_size`: The batch size to initialize the AWQ model with.
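As a quick reference, a call that sets these parameters might look like the sketch below (model path and filename are taken from the Inference example above; values are illustrative):

```python
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "casperhansen/vicuna-7b-v1.5-awq",   # quant_path
    "awq_model_w4_g128.pt",              # quant_filename
    max_new_tokens=512,                  # sizes the kv-cache of fused modules
    fuse_layers=True,                    # replace attention/MLP with fused modules
    batch_size=1,                        # must match the batch size used at generation time
)
```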
</details>

<details>
<summary>Detailed benchmark (CPU vs. GPU)</summary>

Here is the difference between a fast and slow CPU on MPT-7B:

RTX 4090 + Intel i9 13900K (2 different VMs):
- CUDA 12.0, Driver 525.125.06: 134 tokens/s (7.46 ms/token)
- CUDA 12.0, Driver 525.125.06: 117 tokens/s (8.52 ms/token)

RTX 4090 + AMD EPYC 7-Series (3 different VMs):
- CUDA 12.2, Driver 535.54.03: 53 tokens/s (18.6 ms/token)
- CUDA 12.2, Driver 535.54.03: 56 tokens/s (17.71 ms/token)
- CUDA 12.0, Driver 525.125.06: 55 tokens/s (18.15 ms/token)

</details>

### Vicuna 7B (LLaMa-2)
- Note: Blazing fast generation, slow context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMV
- Command: `python examples/benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq-gemv`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 231.393 | 153.632 | 4.66 GB (19.68%) |
| 1 | 64 | 64 | 233.909 | 154.475 | 4.66 GB (19.68%) |
| 1 | 128 | 128 | 233.145 | 152.133 | 4.66 GB (19.68%) |
| 1 | 256 | 256 | 228.562 | 147.692 | 4.67 GB (19.72%) |
| 1 | 512 | 512 | 228.914 | 139.179 | 4.80 GB (20.26%) |
| 1 | 1024 | 1024 | 227.393 | 125.058 | 5.56 GB (23.48%) |
| 1 | 2048 | 2048 | 225.736 | 123.228 | 8.08 GB (34.09%) |
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMM
- Command: `python examples/benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 521.444 | 126.51 | 4.55 GB (19.21%) |
| 1 | 64 | 64 | 2618.88 | 125.428 | 4.57 GB (19.31%) |
| 1 | 128 | 128 | 2808.09 | 123.865 | 4.61 GB (19.44%) |
| 1 | 256 | 256 | 2807.46 | 120.779 | 4.67 GB (19.72%) |
| 1 | 512 | 512 | 2769.9 | 115.08 | 4.80 GB (20.26%) |
| 1 | 1024 | 1024 | 2640.95 | 105.493 | 5.56 GB (23.48%) |
| 1 | 2048 | 2048 | 2341.36 | 104.188 | 8.08 GB (34.09%) |
### MPT 7B
- Note: Blazing fast generation, slow context processing
- GPU: NVIDIA GeForce RTX 3090
- Command: `python examples/benchmark.py --model_path casperhansen/mpt-7b-8k-chat-awq-gemv`
- Version: GEMV
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 187.332 | 136.765 | 3.65 GB (15.42%) |
| 1 | 64 | 64 | 241.026 | 136.476 | 3.67 GB (15.48%) |
| 1 | 128 | 128 | 239.44 | 137.599 | 3.70 GB (15.61%) |
| 1 | 256 | 256 | 233.184 | 137.02 | 3.76 GB (15.88%) |
| 1 | 512 | 512 | 233.082 | 135.633 | 3.89 GB (16.41%) |
| 1 | 1024 | 1024 | 231.504 | 122.197 | 4.40 GB (18.57%) |
| 1 | 2048 | 2048 | 228.307 | 121.468 | 5.92 GB (24.98%) |
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Version: GEMM
- Command: `python examples/benchmark.py --model_path casperhansen/mpt-7b-8k-chat-awq`
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 557.714 | 118.567 | 3.65 GB (15.42%) |
| 1 | 64 | 64 | 2752.9 | 120.772 | 3.67 GB (15.48%) |
| 1 | 128 | 128 | 2982.67 | 119.52 | 3.70 GB (15.61%) |
| 1 | 256 | 256 | 3009.16 | 116.911 | 3.76 GB (15.88%) |
| 1 | 512 | 512 | 2901.91 | 111.607 | 3.95 GB (16.68%) |
| 1 | 1024 | 1024 | 2718.68 | 102.623 | 4.40 GB (18.57%) |
| 1 | 2048 | 2048 | 2363.61 | 101.368 | 5.92 GB (24.98%) |
### Falcon 7B
- Note: Fast generation, fast context processing
- GPU: NVIDIA GeForce RTX 3090
- Command: `python examples/benchmark.py --model_path casperhansen/falcon-7b-awq --quant_file awq_model_w4_g64.pt`
- Version: GEMM
| Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
|-------------:|-----------------:|----------------:|-------------------:|------------------:|:-----------------|
| 1 | 32 | 32 | 466.826 | 95.1413 | 4.47 GB (18.88%) |
| 1 | 64 | 64 | 1920.61 | 94.5963 | 4.48 GB (18.92%) |
| 1 | 128 | 128 | 2406.1 | 94.793 | 4.48 GB (18.92%) |
| 1 | 256 | 256 | 2521.08 | 94.1144 | 4.48 GB (18.92%) |
| 1 | 512 | 512 | 2478.28 | 93.4123 | 4.48 GB (18.92%) |
| 1 | 1024 | 1024 | 2256.22 | 94.0237 | 4.69 GB (19.78%) |
| 1 | 2048 | 2048 | 1831.71 | 94.2032 | 6.83 GB (28.83%) |
## Reference
......
import os
from transformers import AutoConfig
from awq.models import *
from awq.models.base import BaseAWQForCausalLM
...@@ -35,7 +36,9 @@ class AutoAWQForCausalLM:
@classmethod
def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
device='balanced', trust_remote_code=True, fuse_layers=True,
batch_size=1) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
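# NOTE: the fused attention module (QuantAttentionFused) reads AWQ_BATCH_SIZE to size its pre-allocated kv-cache.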
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
......
...@@ -2,15 +2,17 @@ import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
...@@ -41,6 +43,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -41,6 +43,7 @@ class BaseAWQForCausalLM(nn.Module):
auto_scale=True, mse_range=True, run_search=True, run_quant=True, auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"): calib_data="pileval"):
self.quant_config = quant_config self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
if run_search: if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen, self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
...@@ -51,7 +54,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -51,7 +54,7 @@ class BaseAWQForCausalLM(nn.Module):
self.is_quantized = True self.is_quantized = True
@staticmethod @staticmethod
def fuse_layers(model): def fuse_layers(model, quant_config):
pass pass
def _awq_quant(self): def _awq_quant(self):
...@@ -70,13 +73,18 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -70,13 +73,18 @@ class BaseAWQForCausalLM(nn.Module):
module.weight.data, scales, zeros = pseudo_quantize_tensor(
module.weight.data,
get_scale_zp=True,
w_bit=self.quant_config["w_bit"],
q_group_size=self.quant_config["q_group_size"]
)
if self.quant_config["version"] == 'GEMM':
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.quant_config["version"] == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
...@@ -253,7 +261,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -253,7 +261,7 @@ class BaseAWQForCausalLM(nn.Module):
@classmethod @classmethod
def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None, def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
device='balanced', torch_dtype=torch.float16, trust_remote_code=True, device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
safetensors=False, is_quantized=True, fuse_layers=False): safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
# [STEP 1] Download model if path is not a directory # [STEP 1] Download model if path is not a directory
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*"] ignore_patterns = ["*msgpack*", "*h5*"]
...@@ -273,9 +281,12 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -273,9 +281,12 @@ class BaseAWQForCausalLM(nn.Module):
if os.path.exists(quant_config_path):
with open(quant_config_path, 'r') as file:
quant_config = json.loads(file.read())
if "version" not in quant_config.keys():
quant_config["version"] = version
else:
# Default config that works for most models
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
# Load model config and set max generation length # Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'): if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
...@@ -293,7 +304,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -293,7 +304,7 @@ class BaseAWQForCausalLM(nn.Module):
# Only need to replace layers if a model is AWQ quantized # Only need to replace layers if a model is AWQ quantized
if is_quantized: if is_quantized:
# Prepare WQLinear layers, replace nn.Linear # Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config) self._load_quantized_modules(self, model, quant_config, quant_config["version"])
model.tie_weights() model.tie_weights()
...@@ -313,7 +324,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -313,7 +324,7 @@ class BaseAWQForCausalLM(nn.Module):
) )
if fuse_layers: if fuse_layers:
self.fuse_layers(model) self.fuse_layers(model, quant_config)
else: else:
# If not quantized, must load with AutoModelForCausalLM # If not quantized, must load with AutoModelForCausalLM
...@@ -333,7 +344,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -333,7 +344,7 @@ class BaseAWQForCausalLM(nn.Module):
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config) return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_quantized_modules(self, model, quant_config): def _load_quantized_modules(self, model, quant_config, version):
# Real quantization of weights # Real quantization of weights
assert quant_config["zero_point"], "We only support zero_point quantization now." assert quant_config["zero_point"], "We only support zero_point quantization now."
...@@ -351,8 +362,17 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -351,8 +362,17 @@ class BaseAWQForCausalLM(nn.Module):
# Replace nn.Linear with WQLinear # Replace nn.Linear with WQLinear
for name, module in named_linears.items(): for name, module in named_linears.items():
q_linear = WQLinear.from_linear( if version == 'GEMM':
module, quant_config['w_bit'], quant_config['q_group_size'], True) q_linear_module = WQLinear_GEMM
elif version == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
quant_config['w_bit'],
quant_config['q_group_size'],
True
)
q_linear.to(next(layer.parameters()).device) q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear) set_op_by_name(layer, name, q_linear)
......
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer"
@staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config:dict):
fuser = FalconFuser(model)
fuser.fuse_transformer()
@staticmethod @staticmethod
def get_model_layers(model: FalconForCausalLM): def get_model_layers(model: FalconForCausalLM):
return model.transformer.h return model.transformer.h
@staticmethod @staticmethod
def get_act_for_scaling(module: FalconDecoderLayer): def get_act_for_scaling(module: OldFalconDecoderLayer):
return dict( return dict(
is_scalable=True, is_scalable=True,
scale_name="mlp.act", scale_name="mlp.act",
...@@ -22,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -22,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device) model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
@staticmethod @staticmethod
def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs): def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
layers = [] layers = []
# Falcon 7B (older architecture) # Falcon 7B (older architecture)
...@@ -57,3 +62,47 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM): ...@@ -57,3 +62,47 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
)) ))
return layers return layers
from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer
class FalconFuser:
def __init__(self, model: FalconForCausalLM):
self.model = model
def fuse_transformer(self):
blocks = []
module: OldFalconDecoderLayer
for module in self.model.transformer.h:
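# Falcon-7B uses 71 attention heads and a single input_layernorm; newer Falcon variants use separate ln_attn/ln_mlp.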
if module.config.num_attention_heads == 71:
input_layernorm = module.input_layernorm
ln_attn = None
ln_mlp = None
new_decoder_arch = False
else:
input_layernorm = None
ln_attn = module.ln_attn
ln_mlp = module.ln_mlp
new_decoder_arch = True
blocks.append(FalconDecoderLayer(
hidden_size=module.config.hidden_size,
n_heads=module.config.num_attention_heads,
qkv_layer=module.self_attention.query_key_value,
o_proj=module.self_attention.dense,
mlp=module.mlp,
dev=next(iter(module.state_dict().values())).device,
max_seq_len=self.model.config.max_new_tokens,
input_layernorm=input_layernorm,
ln_attn=ln_attn,
ln_mlp=ln_mlp,
new_decoder_arch=new_decoder_arch
))
self.model.transformer = FalconModel(
self.model.config.vocab_size,
blocks,
self.model.transformer.word_embeddings,
self.model.transformer.ln_f,
)
\ No newline at end of file
...@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings" max_new_tokens_key = "max_position_embeddings"
@staticmethod @staticmethod
def fuse_layers(model: LlamaForCausalLM): def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
fuser = LlamaFuser(model) fuser = LlamaFuser(model, quant_config)
fuser.fuse_attention() fuser.fuse_attention()
fuser.fuse_rmsnorm() fuser.fuse_rmsnorm()
fuser.fuse_mlp() fuser.fuse_mlp()
...@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM): ...@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
return layers return layers
import torch
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser: class LlamaFuser:
def __init__(self, model): def __init__(self, model, quant_config):
self.model = model self.model = model
self.quant_config = quant_config
self.attention_modules: List[Tuple[str, LlamaAttention]] = [ self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules() (name, module) for name, module in self.model.named_modules()
...@@ -95,11 +96,10 @@ class LlamaFuser: ...@@ -95,11 +96,10 @@ class LlamaFuser:
def fuse_attention(self): def fuse_attention(self):
for name, module in self.attention_modules: for name, module in self.attention_modules:
qkv_layer: WQLinear = self._fuse_qkv(module) qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
attn = QuantLlamaAttention( attn = QuantAttentionFused(
module.hidden_size, module.hidden_size,
module.num_heads, module.num_heads,
module.num_key_value_heads,
qkv_layer, qkv_layer,
module.o_proj, module.o_proj,
next(iter(qkv_layer.state_dict().values())).device, next(iter(qkv_layer.state_dict().values())).device,
...@@ -108,12 +108,15 @@ class LlamaFuser: ...@@ -108,12 +108,15 @@ class LlamaFuser:
set_module_name(self.model, name, attn) set_module_name(self.model, name, attn)
def _fuse_qkv(self, module: LlamaAttention): def _fuse_qkv(self, module: LlamaAttention):
# get qkv and bias
q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module if isinstance(q_proj, WQLinear_GEMV):
qkv_layer = WQLinear( q_linear = WQLinear_GEMV
else:
q_linear = WQLinear_GEMM
qkv_layer = q_linear(
q_proj.w_bit, q_proj.w_bit,
q_proj.group_size, q_proj.group_size,
q_proj.in_features, q_proj.in_features,
...@@ -122,10 +125,16 @@ class LlamaFuser: ...@@ -122,10 +125,16 @@ class LlamaFuser:
next(iter(module.state_dict().values())).device next(iter(module.state_dict().values())).device
) )
# replace buffers with real weights if isinstance(qkv_layer, WQLinear_GEMV):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
qkv_layer.split_k_iters = q_proj.split_k_iters
else:
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
qkv_layer.bias = bias qkv_layer.bias = bias
return qkv_layer return qkv_layer
......
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_new_tokens_key = "max_seq_len"
@staticmethod @staticmethod
def fuse_layers(model: MptForCausalLM): def fuse_layers(model: MptForCausalLM, quant_config:dict):
fuser = MptFuser(model) fuser = MptFuser(model)
fuser.fuse_mlp() fuser.fuse_transformer()
@staticmethod @staticmethod
def get_model_layers(model: MptForCausalLM): def get_model_layers(model: MptForCausalLM):
return model.transformer.blocks return model.transformer.blocks
@staticmethod @staticmethod
def get_act_for_scaling(module: MptBlock): def get_act_for_scaling(module: OldMptBlock):
return dict( return dict(
is_scalable=True, is_scalable=True,
scale_name="ffn.act", scale_name="ffn.act",
...@@ -29,7 +29,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM): ...@@ -29,7 +29,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
model.transformer.emb_drop = model.transformer.emb_drop.to(device) model.transformer.emb_drop = model.transformer.emb_drop.to(device)
@staticmethod @staticmethod
def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs): def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
layers = [] layers = []
# attention input # attention input
...@@ -67,24 +67,38 @@ class MptAWQForCausalLM(BaseAWQForCausalLM): ...@@ -67,24 +67,38 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel
class MptFuser: class MptFuser:
def __init__(self, model): def __init__(self, model: MptForCausalLM):
self.model = model self.model = model
self.mlp_modules: List[Tuple[str, MptMLP]] = [ self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
(name, module) for name, module in self.model.named_modules() (name, module) for name, module in self.model.named_modules()
if isinstance(module, MptMLP) if 'mptblock' in module.__class__.__name__.lower()
] ]
def fuse_attention(self): def fuse_transformer(self):
pass blocks = []
def fuse_layernorm(self): module: OldMptBlock
pass for module in self.model.transformer.blocks:
blocks.append(MPTBlock(
self.model.config.d_model,
self.model.config.n_heads,
module.attn.Wqkv,
module.attn.out_proj,
module.ffn,
module.norm_1,
module.norm_2,
next(iter(module.state_dict().values())).device,
self.model.config.max_new_tokens
))
def fuse_mlp(self): self.model.transformer = MPTModel(
for name, module in self.mlp_modules: self.model.config.vocab_size,
mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj) blocks,
set_module_name(self.model, name, mlp) self.model.transformer.wte,
\ No newline at end of file self.model.transformer.norm_f,
)
\ No newline at end of file
from .fused_norm import *
from .fused_attn import *
from .fused_mlp import *
import torch.nn as nn
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
import os
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
def build_alibi_bias(
n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
if full:
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, seq_len, 1
)
alibi_bias = alibi_bias.abs().mul(-1)
slopes = gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
# print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding_neox(
positions,
query,
key,
self.dim,
self.cos_sin_cache
)
return query, key
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
use_alibi=False, attention_shapes=None):
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = self.hidden_size // num_heads
self.qkv_proj = qkv_layer
self.o_proj = o_proj
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
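# The kv-cache below is pre-allocated for this fixed batch size, set via AutoAWQForCausalLM.from_quantized(batch_size=...).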
self.attention_shapes = attention_shapes if attention_shapes is not None else {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_local_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xk_reshape": (self.n_local_heads, self.head_dim // 8, 8),
"xq_view": (self.n_local_heads, self.head_dim),
"xk_view": (self.n_local_heads, self.head_dim),
"xv_view": (self.n_local_heads, self.head_dim),
"single_xq_view": (self.n_local_heads, self.head_dim),
"single_xk_view": (self.n_local_heads, self.head_dim),
"single_xv_view": (self.n_local_heads, self.head_dim)
}
self.cache_v = (
torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
)
self.cache_k = (
torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
)
if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // num_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True
def forward(
self,
hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
raise RuntimeError(
f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
)
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
xq = self.attention_shapes["xq_slice"](xqkv)
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)
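# Prefill path (seqlen > 1): process the whole prompt and write keys/values into the cache.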
if seqlen > 1:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
if not self.use_alibi:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
values_store = xv.transpose(2, 1)
keys_store = (
xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
.permute(0, 2, 3, 1, 4)
.contiguous()
)
self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
keys = xk
values = xv
past_key_value = (xk, xv) if use_cache else None
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores += self.alibi_bias[..., :seqlen]
if attention_mask is not None:
scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
else:
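# Decode path (seqlen == 1): the fused single_query_attention kernel attends over the pre-filled kv-cache.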
# xq = xq[:, 0, :, :]
# xk = xk[:, 0, :, :]
# xv = xv[:, 0, :, :]
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
past_key_value = (xk, xv) if use_cache else None
attention_weight = awq_inference_engine.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache_k, # key cache
self.cache_v, # value cache
None, # length per sample
self.alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
10000, # rotary embedding base
self.is_neox, # is neox
)
attention_weight = attention_weight.reshape(bsz, 1, -1)
attn_output = self.o_proj(attention_weight)
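# Advance the kv-cache write position while generating with the cache; reset it when the cache is not used.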
if use_cache:
self.start_pos += seqlen
else:
self.start_pos = 0
return attn_output, attention_weight, past_key_value
\ No newline at end of file
import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MPTBlock(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
super().__init__()
self.n_heads = n_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1
self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
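# MPT uses ALiBi position biases instead of rotary embeddings, hence use_alibi=True.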
self.norm_2 = norm_2
self.ffn = mpt_mlp.to(dev)
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
)
h = hidden_states + attn_output
out = h + self.ffn.forward(self.norm_2(h))
return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
super().__init__()
self.n_heads = n_heads
self.hidden_size = hidden_size
self.new_decoder_arch = new_decoder_arch
attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False,
attention_shapes=attention_shapes
).to(dev)
if new_decoder_arch:
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
else:
self.input_layernorm = input_layernorm # before attention
self.mlp = mlp
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
if new_decoder_arch:
kv_heads = 8
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, n_heads+(kv_heads*2), max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, n_heads+(kv_heads*2), head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads+(kv_heads*2), head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :,0],
"xk_slice": lambda xqkv: xqkv[:, :, :,1],
"xv_slice": lambda xqkv: xqkv[:, :, :,2],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (1, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, 8, head_dim),
"single_xv_view": (1, 8, head_dim)
}
else:
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, 1, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads+2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
"xk_reshape": (1, head_dim // 8, 8),
"xq_view": (n_heads, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim)
}
return self.attention_shapes
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
if self.new_decoder_arch:
layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
layernorm_out = self.input_layernorm(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=layernorm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
)
h_attn = hidden_states + attn_output
if self.new_decoder_arch:
h_mlp = self.mlp.forward(mlp_layernorm_out)
else:
h_mlp = self.mlp.forward(layernorm_out)
out = h_attn + h_mlp
return out, None, past_key_value
import torch
import torch.nn as nn
import awq_inference_engine
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
class QuantMPTMLP(nn.Module):
def __init__(
self,
up_proj,
act,
down_proj
):
super().__init__()
self.register_buffer('up_proj_qweight', up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros)
self.up_proj = up_proj
self.act = act
self.down_proj = down_proj
def forward(self, x: torch.Tensor):
x = x.reshape(-1, x.shape[-1])
x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
return self.down_proj(self.act(x))
class QuantLlamaMLP(nn.Module): class QuantLlamaMLP(nn.Module):
...@@ -31,7 +9,7 @@ class QuantLlamaMLP(nn.Module): ...@@ -31,7 +9,7 @@ class QuantLlamaMLP(nn.Module):
self, self,
gate_proj, gate_proj,
down_proj, down_proj,
up_proj, up_proj
): ):
super().__init__() super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight) self.register_buffer('gate_proj_qweight', gate_proj.qweight)
...@@ -47,19 +25,32 @@ class QuantLlamaMLP(nn.Module): ...@@ -47,19 +25,32 @@ class QuantLlamaMLP(nn.Module):
self.w_bit = gate_proj.w_bit self.w_bit = gate_proj.w_bit
self.down_proj = down_proj self.down_proj = down_proj
def forward(self, x): if isinstance(down_proj, WQLinear_GEMV):
return self.down_proj(self.our_llama_mlp(x)) self.linear = awq_inference_engine.gemv_forward_cuda
self.group_size = down_proj.group_size
else:
self.linear = awq_inference_engine.gemm_forward_cuda
self.group_size = 8
def our_llama_mlp(self, x): def forward(self, x):
out_shape = x.shape[:-1] + (self.intermediate_size, ) out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
gate_output = awq_inference_engine.gemm_forward_cuda( gate_output = self.linear(
x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8 x,
self.gate_proj_qweight,
self.gate_proj_scales,
self.gate_proj_qzeros,
self.group_size,
) )
gate_output = F.silu(gate_output) up_output = self.linear(
up_output = awq_inference_engine.gemm_forward_cuda( x,
x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8 self.up_proj_qweight,
self.up_proj_scales,
self.up_proj_qzeros,
self.group_size,
) )
c = gate_output * up_output x = F.silu(gate_output) * up_output
c = c.reshape(out_shape) x = x.reshape(out_shape)
return c x = self.down_proj(x)
return x
\ No newline at end of file
import torch
import torch.nn as nn
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
class MPTModel(nn.Module):
def __init__(self, vocab_size, blocks, wte, norm_f):
super().__init__()
self.vocab_size = vocab_size
self.wte = wte
self.blocks: list[MPTBlock] = nn.ModuleList(blocks)
self.norm_f = norm_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
_bsz, seqlen = input_ids.shape
h = self.wte(input_ids)
mask = None
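# A causal mask is only needed when processing more than one token; single-token decode relies on the cache position inside the fused attention.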
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
)
mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
for layer in self.blocks:
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
class FalconModel(nn.Module):
def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
super().__init__()
self.vocab_size = vocab_size
self.word_embeddings = word_embeddings
self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
self.ln_f = ln_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
# NOTE: falcon input ids contain full context
# after context is processed, slice to latest token
if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1:
input_ids = input_ids[:, self.blocks[0].attn.start_pos:]
_bsz, seqlen = input_ids.shape
h = self.word_embeddings(input_ids)
mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
)
mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
for layer in self.blocks:
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.ln_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# [max_position, rot_dim]
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding(
positions,
query,
key,
self.dim,
self.cos_sin_cache,
True # is_neox
)
return query, key
class QuantLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
num_kv_heads,
qkv_proj,
o_proj,
dev,
max_new_tokens,
use_hf_rotary=False
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.use_hf_rotary = use_hf_rotary
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads}).")
self.qkv_proj = qkv_proj
self.o_proj = o_proj
if use_hf_rotary:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
else:
self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
if self.use_hf_rotary:
# get qkv
qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
query, key, value = torch.split(qkv_states, 1, dim=2)
del qkv_states
# reshape for hf rotary
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
else:
# get qkv
query, key, value = qkv_states.chunk(chunks=3, dim=-1)
del qkv_states
# [num_tokens, num_heads * head_size]
query_batch_size, query_len, _ = query.shape
query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)
# [num_tokens, num_kv_heads * head_size]
key_batch_size, key_len, _ = key.shape
key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)
# [num_tokens]
positions = position_ids.view(-1).to(query.device)
query, key = self.rotary_emb(query, key, positions)
query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
is_causal = past_key_value is None
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
value = value.to(key.device)
if past_key_value is not None:
# reuse k, v, self_attention
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
if use_cache:
# Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
# which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
key = key.contiguous()
value = value.contiguous()
query = query.contiguous()
past_key_value = (key, value) if use_cache else None
# with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
del query, key, value
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
...@@ -4,17 +4,24 @@ import torch.nn as nn
import awq_inference_engine # with CUDA kernels
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
...@@ -25,6 +32,7 @@ class WQLinear(nn.Module):
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
...@@ -97,3 +105,100 @@ class WQLinear(nn.Module):
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
class WQLinear_GEMV(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = (32 // self.w_bit)
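# With 4-bit weights, 8 values are packed into each 32-bit integer.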
self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
dtype=torch.float16,
device=scales.device
)
qscales[:, :scales.shape[1]] = scales
awq_linear.scales = qscales
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros(
(zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
dtype=torch.int32,
device=zeros.device,
)
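# pack the zero points the same way; the bounds check below skips the padded tail beyond zeros.shape[1]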
for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit is supported for now.")
for i in range(pack_num):
if col * pack_num + order_map[i] >= zeros.shape[1]:
continue
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
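# collapse all leading dimensions into one batch dimension, run the fused int4 GEMV kernel, then restore the shape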
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
\ No newline at end of file
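For orientation, a minimal smoke-test sketch of how the new module is exercised. The import path is a guess (adjust to your checkout), the compiled `awq_inference_engine` extension must already be built, and the scales/zeros below are random placeholders rather than real AWQ statistics, so the output is numerically meaningless; the point is only the expected shapes and the call sequence.

```python
import torch
import torch.nn as nn

# hypothetical import path -- adjust to wherever WQLinear_GEMV lives in your checkout
from awq.modules.linear import WQLinear_GEMV

in_features, out_features, group_size = 4096, 4096, 128
linear = nn.Linear(in_features, out_features, bias=True).half().cuda()

# placeholder statistics with the shapes from_linear expects:
# (out_features, in_features // group_size) for both scales and zeros
n_groups = in_features // group_size
scales = torch.rand(out_features, n_groups, dtype=torch.float16, device="cuda") * 0.01 + 1e-3
zeros = torch.randint(0, 16, (out_features, n_groups), device="cuda").half()

q_linear = WQLinear_GEMV.from_linear(linear, w_bit=4, group_size=group_size,
                                     scales=scales, zeros=zeros)

x = torch.randn(1, 1, in_features, dtype=torch.float16, device="cuda")
y = q_linear(x)  # dispatches to awq_inference_engine.gemv_forward_cuda
print(y.shape)   # torch.Size([1, 1, 4096])
```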
...@@ -43,7 +43,7 @@ def auto_clip_layer(w,
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, **quant_config)
q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"])
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
......
...@@ -7,7 +7,7 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from .qmodule import ScaledActivation
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
...@@ -98,7 +98,7 @@ def auto_scale_block(awq_model,
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if quant_config['w_bit'] is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(p, **quant_config).detach()
def w_quantize_func(p): return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
else:
def w_quantize_func(p): return p
......
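The two hunks above replace `**quant_config` with explicit `w_bit`/`q_group_size` keyword arguments, presumably because `quant_config` can carry keys the function does not accept. For readers who have not seen it, a rough stand-in for what such a group-wise INT4 fake-quantization function typically does (an illustrative sketch, not the repository's implementation):

```python
import torch

def pseudo_quantize_tensor_sketch(w: torch.Tensor, w_bit: int = 4, q_group_size: int = 128) -> torch.Tensor:
    # assumes w.numel() is divisible by q_group_size
    orig_shape = w.shape
    w = w.reshape(-1, q_group_size)                       # one row per quantization group
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** w_bit - 1                              # 15 for 4-bit
    scale = (max_val - min_val).clamp(min=1e-5) / max_int
    zero = (-min_val / scale).round()
    # quantize to [0, max_int], then dequantize back to the original scale
    w_q = (torch.clamp((w / scale).round() + zero, 0, max_int) - zero) * scale
    return w_q.reshape(orig_shape)
```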
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {
#ifdef ENABLE_BF16
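// The helpers below emulate bf16/bf16x2 arithmetic in fp32 on architectures older than sm_80 (Ampere),
// where the native __hadd2/__hmul2/__hfma2 bf16 instructions are unavailable.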
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t; t.x = x; t.y = y; return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace fastertransformer
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
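// Computes the dynamic shared-memory requirement, opts the kernel into more than the default 48 KB of
// shared memory when needed, and launches one block per (head, batch) pair.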
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
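// Each thread handles 16 bytes of a value row, so Dh_MAX * sizeof(T) / 16 threads cover one head;
// the sequence length (tlength) then selects threads-per-key and block size below.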
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
// printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#undef MMHA_LAUNCH_KERNEL
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
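// Dispatch on the per-head dimension; the second template argument (Dh_MAX) is the size Dh is padded
// up to for the launch configuration (assumed: register and shared-memory sizing).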
switch (params.hidden_size_per_head) {
case 32:
mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 48:
mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 64:
mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 80:
mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 96:
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 112:
mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 160:
mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 192:
mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 224:
mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 256:
mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////