Unverified Commit fd6482ad authored by Xu Kai and committed by GitHub

[inference] Refactor inference architecture (#5057)



* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] udpate example (#5053)

* udpate example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent bc09b95f
# 🚀 Colossal-Inference
## Table of Contents
- [💡 Introduction](#introduction)
- [🔗 Design](#design)
- [🔨 Usage](#usage)
- [Quick start](#quick-start)
- [Example](#example)
- [📊 Performance](#performance)
## Introduction
## Design
Colossal Inference is composed of three main components:
1. High-performance kernels and ops: inspired by existing libraries and modified accordingly.
2. Efficient memory management mechanism: includes the key-value cache manager, allowing for zero memory waste during inference.
   1. `cache manager`: serves as a memory manager for the key-value cache; it integrates functions such as memory allocation, indexing and release.
   2. `batch_infer_info`: holds all essential elements of a batch inference and is updated every batch.
3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods.
   1. `HybridEngine`: a high-level interface that integrates with `Shardformer`, especially for multi-card (tensor parallel, pipeline parallel) inference (see the sketch below).
   2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference.
   3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, used by `Shardformer` to split the model forward for tensor parallelism.
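
Below is a minimal sketch of how these pieces fit together, using the `CaiInferEngine` interface described in the [Usage](#usage) section (paths and sizes are placeholders): the engine selects an inference policy for the model, hands the model to `Shardformer` for sharding, and builds the key-value cache managers used during generation.

```python
import colossalai
from transformers import LlamaForCausalLM
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

colossalai.launch_from_torch(config={})
model = LlamaForCausalLM.from_pretrained("/path/to/model")

# The policy tells Shardformer which inference forwards to patch in
# (e.g. modeling.llama.LlamaInferenceForwards) and how to split the layers
# across tensor-parallel ranks; if omitted, the engine looks it up by model type.
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_output_len=32,
)
```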
## Architecture of inference:
In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.
![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png)
## Roadmap of our implementation
- [x] context forward
- [x] token forward
- [x] support flash-decoding
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [x] Support all models
- [x] Llama
- [x] Llama-2
- [x] Bloom
- [x] Chatglm2
- [x] Quantization
- [x] GPTQ
- [x] SmoothQuant
- [ ] Benchmarking for all models
## Get started
### Requirements
Install dependencies.
```bash
pip install -r requirements/requirements-infer.txt
# if you want to use smoothquant quantization, please install torch-int
git clone --recurse-submodules https://github.com/Guangxuan-Xiao/torch-int.git
cd torch-int
git checkout 65266db1eadba5ca78941b789803929e6e6c6856
pip install -r requirements.txt
source environment.sh
bash build_cutlass.sh
python setup.py install
```
### Docker
```bash
cd /path/to/ColossalAI
pip install -e .
```
## Usage
### Quick start
```bash
cd ColossalAI/examples
python hybrid_llama.py --path /path/to/model --tp_size 2 --pp_size 2 --batch_size 4 --max_input_size 32 --max_out_len 16 --micro_batch_size 2
```
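
Note: with `--tp_size 2 --pp_size 2` the script needs `tp_size * pp_size = 4` processes, so launch it with a distributed launcher (e.g. `torchrun --standalone --nproc_per_node 4 hybrid_llama.py ...`), assuming the script initializes the distributed environment via `colossalai.launch_from_torch` as in the example below.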
### Example
```python
# import modules
import colossalai
import torch.distributed as dist
from colossalai.inference import CaiInferEngine
from transformers import LlamaForCausalLM, LlamaTokenizer

# launch the distributed environment
colossalai.launch_from_torch(config={})
# load original model and tokenizer
model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
# generate token ids
input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
data = tokenizer(input, return_tensors='pt')
# set parallel parameters
tp_size=2
pp_size=2
max_output_len=32
micro_batch_size=1
# initialize the inference engine
engine = CaiInferEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_output_len=max_output_len,
micro_batch_size=micro_batch_size,
)
# inference
output = engine.generate(data)
# get results
if dist.get_rank() == 0:
assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
```
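
The returned `output` contains token ids. A minimal decoding sketch, assuming (as the assertion above suggests) that `output` is a list of per-sequence token-id lists:

```python
# decode the generated token ids back to text with the same tokenizer
if dist.get_rank() == 0:
    for token_ids in output:
        print(tokenizer.decode(token_ids, skip_special_tokens=True))
```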
## Performance
Currently the stats below are measured on an A100 (single GPU). Token latency is computed from the average of the context-forward and decoding-forward processes, i.e. both processes are combined when calculating token generation time. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
### Tensor Parallelism Inference
#### Llama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png)
#### Bloom
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png)
### Pipeline Parallelism Inference
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the Hugging Face pipeline. The test environments are 2 * A10 (20G) and 2 * A800 (80G). We set input length = 1024 and output length = 128.
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: |
| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM |
| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM |
![ppllama7b](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama7b.png)
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
![ppllama13](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a10-llama13b.png)
#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
![ppllama7b_a800](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/pp-a800-llama7b.png)
### Quantized Llama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| auto-gptq | 199.20 | 232.56 | 253.26 |
| smooth-quant | 142.28 | 222.96 | 300.59 |
| colossal-gptq | 231.98 | 388.87 | 573.03 |
![quant](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-quant.png)
The results of more models are coming soon!
from .engine import CaiInferEngine
from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
from typing import Union
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import logging
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from ..kv_cache import MemoryManager
from .microbatch_manager import MicroBatchManager
from .policies import model_policy_map
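# Axes of the 2D ProcessGroupMesh built in CaiInferEngine.__init__:
# axis 0 (PP_AXIS) enumerates pipeline stages, axis 1 (TP_AXIS) enumerates tensor-parallel ranks.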
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
"BloomForCausalLM",
"LlamaGPTQForCausalLM",
"SmoothLlamaForCausalLM",
"ChatGLMForConditionalGeneration",
]
class CaiInferEngine:
"""
    CaiInferEngine is a high-level class that handles tensor-parallel and pipeline-parallel inference.
Args:
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
        model (`nn.Module`): the original (non-pipelined) model, which will be sharded by `ShardFormer`.
        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
Example:
```python
    from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
    import colossalai
    from transformers import LlamaForCausalLM, LlamaTokenizer

    colossalai.launch_from_torch(config={})

    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
    tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
    # assume the model is inferred with 2 pipeline stages
    inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

    input = ["Introduce a landmark in China ", "Introduce a landmark in China "]
    data = tokenizer(input, return_tensors='pt')
    output = inferengine.generate(data)
```
"""
def __init__(
self,
tp_size: int = 1,
pp_size: int = 1,
dtype: str = "fp16",
model: nn.Module = None,
model_policy: Policy = None,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
quant: str = None,
verbose: bool = False,
        # TODO: implement early_stopping and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
if quant == "gptq":
from ..quant.gptq import GPTQManager
self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
model = model.model
elif quant == "smoothquant":
model = model.model
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert model, "Model should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
        assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq', or None"
self.pp_size = pp_size
self.tp_size = tp_size
self.quant = quant
logger = logging.get_logger(__name__)
if quant == "smoothquant" and dtype != "fp32":
dtype = "fp32"
logger.warning_once("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")
if dtype == "fp16":
self.dtype = torch.float16
model.half()
elif dtype == "bf16":
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
if model_policy is None:
model_policy = model_policy_map[model.config.model_type]()
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)
def generate(self, input_list: Union[BatchEncoding, dict]):
"""
Args:
            input_list (Union[BatchEncoding, dict]): the input batch, a `BatchEncoding` or `dict` containing the tokenized inputs.

        Returns:
            out (list): a list of output data, each element is a list of generated token ids.
            timestamp (float): the time cost of the inference, only returned when `verbose` is `True`.
"""
assert isinstance(
input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding):
input_list = input_list.data
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
else:
return out
def _shardformer(self, model, model_policy, stage_manager, tp_group):
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"quant": self.quant},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
    def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
if model.config.model_type == "llama":
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_key_value_heads // self.tp_size
num_hidden_layers = (
model.config.num_hidden_layers
if hasattr(model.config, "num_hidden_layers")
else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
elif model.config.model_type == "bloom":
head_dim = model.config.hidden_size // model.config.n_head
head_num = model.config.n_head // self.tp_size
num_hidden_layers = model.config.n_layer
layer_num = num_hidden_layers // self.pp_size
elif model.config.model_type == "chatglm":
head_dim = model.config.hidden_size // model.config.num_attention_heads
if model.config.multi_query_attention:
head_num = model.config.multi_query_group_num // self.tp_size
else:
head_num = model.config.num_attention_heads // self.tp_size
num_hidden_layers = model.config.num_layers
layer_num = num_hidden_layers // self.pp_size
else:
raise NotImplementedError("Only support llama, bloom and chatglm model.")
if self.quant == "smoothquant":
cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
else:
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
from enum import Enum
from typing import Dict
import torch
from ..kv_cache import BatchInferState, MemoryManager
__all__ = "MicroBatchManager"
class Status(Enum):
PREFILL = 1
GENERATE = 2
DONE = 3
COOLDOWN = 4
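# State machine of a micro batch as tracked by MicroBatchManager:
#   PREFILL  - no description exists for this buffer slot yet (the prompt/context forward comes next)
#   GENERATE - tokens are being produced and the target length has not been reached
#   COOLDOWN - one token away from the target length
#   DONE     - the target length has been reached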
class MicroBatchDescription:
"""
    This is the class to record the information of each micro batch, and also do some update operations.
    It is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
    details, please refer to the docs of those two classes below.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
        max_input_len (int): the maximum input length.
        max_output_len (int): the maximum output length.
        cache_manager (MemoryManager): the memory manager holding the key-value cache for this micro batch.
"""
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
) -> None:
self.mb_length = inputs_dict["input_ids"].shape[-1]
self.target_length = self.mb_length + max_output_len
self.infer_state = BatchInferState.init_from_batch(
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
)
# print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
def update(self, *args, **kwargs):
pass
@property
def state(self):
"""
        Return the state of the current micro batch: DONE when the current length equals the target length,
        COOLDOWN when it is one token short of the target length, otherwise GENERATE.
"""
# TODO: add the condition for early stopping
if self.cur_length == self.target_length:
return Status.DONE
elif self.cur_length == self.target_length - 1:
return Status.COOLDOWN
else:
return Status.GENERATE
@property
def cur_length(self):
"""
        Return the current sequence length of the micro batch.
"""
class HeadMicroBatchDescription(MicroBatchDescription):
"""
    This class records the information of the first pipeline stage. The first stage holds the attributes `input_ids`, `attention_mask`
    and `new_tokens`, where `new_tokens` are the tokens generated so far. Due to the pipeline schedule, the way the
    information is updated and the condition used to determine the state differ from the other stages.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
        max_input_len (int): the maximum input length.
        max_output_len (int): the maximum output length.
        cache_manager (MemoryManager): the memory manager holding the key-value cache for this micro batch.
"""
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
) -> None:
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
assert inputs_dict is not None
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict["input_ids"]
self.attn_mask = inputs_dict["attention_mask"]
self.new_tokens = None
def update(self, new_token: torch.Tensor = None):
if new_token is not None:
self._update_newtokens(new_token)
if self.state is not Status.DONE and new_token is not None:
self._update_attnmask()
def _update_newtokens(self, new_token: torch.Tensor):
if self.new_tokens is None:
self.new_tokens = new_token
else:
self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
def _update_attnmask(self):
self.attn_mask = torch.cat(
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
)
@property
def cur_length(self):
"""
When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token
"""
if self.new_tokens is None:
return self.mb_length
else:
return self.mb_length + len(self.new_tokens[0])
class BodyMicroBatchDescription(MicroBatchDescription):
"""
    This class records the information of the pipeline stages other than the first one; these stages hold the attributes `hidden_states` and `past_key_values`.

    Args:
        inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Non-first stages only receive hidden states from the previous stage.
"""
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
) -> None:
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
@property
def cur_length(self):
"""
        Return the maximum sequence length recorded in `infer_state` for this micro batch.
"""
return self.infer_state.seq_len.max().item()
class MicroBatchManager:
"""
MicroBatchManager is a class that manages the micro batch.
Args:
stage (int): stage id of current stage.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
"""
def __init__(
self,
stage: int,
micro_batch_size: int,
micro_batch_buffer_size: int,
max_input_len: int,
max_output_len: int,
cache_manager_list: MemoryManager,
):
self.stage = stage
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.cache_manager_list = cache_manager_list
self.mb_descrption_buffer = {}
self.new_tokens_buffer = {}
self.idx = 0
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
)
else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
)
def step(self, new_token: torch.Tensor = None):
"""
        Update the state of the micro batch manager: the current micro batch description is updated with the
        newly generated token (if any) and the resulting state is returned.

        Args:
            new_token (torch.Tensor): the new token generated by the current stage.
"""
# Add descrption first if the descrption is None
self.cur_descrption.update(new_token)
return self.cur_state
def export_new_tokens(self):
new_tokens_list = []
for i in self.mb_descrption_buffer.values():
new_tokens_list.extend(i.new_tokens.tolist())
return new_tokens_list
def is_micro_batch_done(self):
if len(self.mb_descrption_buffer) == 0:
return False
for mb in self.mb_descrption_buffer.values():
if mb.state != Status.DONE:
return False
return True
def clear(self):
self.mb_descrption_buffer.clear()
for cache in self.cache_manager_list:
cache.free_all()
def next(self):
self.idx = (self.idx + 1) % self.buffer_size
def _remove_descrption(self):
self.mb_descrption_buffer.pop(self.idx)
@property
def cur_descrption(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx)
@property
def cur_infer_state(self):
if self.cur_descrption is None:
return None
return self.cur_descrption.infer_state
@property
def cur_state(self):
"""
        Return the state of the current micro batch; when the current description is None, the state is PREFILL.
"""
if self.cur_descrption is None:
return Status.PREFILL
return self.cur_descrption.state
from .bloom import BloomInferenceForwards
from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards
__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards", "ChatGLM2InferenceForwards"]
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.nn import functional as F
from transformers.models.bloom.modeling_bloom import (
BaseModelOutputWithPastAndCrossAttentions,
BloomAttention,
BloomBlock,
BloomForCausalLM,
BloomModel,
)
from transformers.utils import logging
from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
from colossalai.pipeline.stage_manager import PipelineStageManager
try:
from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_bloom_context_attention_fwd,
)
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False
def generate_alibi(n_head, dtype=torch.float16):
"""
This method is adapted from `_generate_alibi` function
in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
of the ModelTC/lightllm GitHub repository.
This method is originally the `build_alibi_tensor` function
in `transformers/models/bloom/modeling_bloom.py`
of the huggingface/transformers GitHub repository.
"""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)]
def get_slopes(n):
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2)
slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined
slopes = get_slopes(n_head)
return torch.tensor(slopes, dtype=dtype)
class BloomInferenceForwards:
"""
    This class serves as a micro library of bloom inference forwards.
    We intend to replace the forward methods of BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
    as well as the prepare_inputs_for_generation method of BloomForCausalLM.
    For future improvement, we might want to skip replacing the methods of BloomForCausalLM
    and call BloomModel.forward iteratively in the inference engine.
"""
@staticmethod
def bloom_for_causal_lm_forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = False,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
tp_group: Optional[dist.ProcessGroup] = None,
**deprecated_arguments,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
        # If this is the first stage and hidden_states is not None, go through lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
outputs = BloomInferenceForwards.bloom_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
tp_group=tp_group,
)
return outputs
@staticmethod
def bloom_model_forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
tp_group: Optional[dist.ProcessGroup] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
# add warnings here
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# first stage
if stage_manager.is_first_stage():
# check inputs and inputs embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
# other stage
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
if seq_length != 1:
# prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if attention_mask is None:
attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
# NOTE revise: we might want to store a single 1D alibi(length is #heads) in model,
# or store to BatchInferState to prevent re-calculating
# When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
curr_tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
alibi = (
generate_alibi(self.num_heads * tp_size)
.contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
.cuda()
)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
block = self.h[idx]
outputs = block(
hidden_states,
layer_past=past_key_value,
attention_mask=causal_mask,
head_mask=head_mask[idx],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = outputs[0]
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
hidden_states = self.ln_f(hidden_states)
# update indices
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
        # always return a dict for intermediate stages
return {"hidden_states": hidden_states}
@staticmethod
def bloom_block_forward(
self: BloomBlock,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
infer_state=infer_state,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
        k = key_layer.reshape(-1, H, D_HEAD)  # [batch_size * q_length, H, D_HEAD], q_length == 1 during decoding
        v = value_layer.reshape(-1, H, D_HEAD)  # [batch_size * q_length, H, D_HEAD], q_length == 1 during decoding
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if infer_state.is_context_stage:
# context process
max_input_len = q_length
b_start_loc = infer_state.start_loc
b_seq_len = infer_state.seq_len[:batch_size]
q = query_layer.reshape(-1, H, D_HEAD)
copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
# output = self.output[:batch_size*q_length, :, :]
output = torch.empty_like(q)
if HAS_LIGHTLLM_KERNEL:
lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
else:
bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
else:
# query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
# need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
assert q_length == 1, "for non-context process, we only support q_length == 1"
q = query_layer.reshape(-1, H, D_HEAD)
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(k)
cache_v.copy_(v)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
b_start_loc = infer_state.start_loc
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
token_attention_fwd(
q,
mem_manager.key_buffer[layer_id],
mem_manager.value_buffer[layer_id],
output,
b_loc,
b_start_loc,
b_seq_len,
infer_state.max_len_in_batch,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# dropout is not required here during inference
output_tensor = residual + output_tensor
outputs = (output_tensor, present)
assert output_attentions is False, "we do not support output_attentions at this time"
return outputs
from typing import List, Optional, Tuple
import torch
from transformers.utils import logging
from colossalai.inference.kv_cache import BatchInferState
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
GLMTransformer,
SelfAttention,
split_tensor_along_last_dim,
)
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
def get_masks(self, input_ids, past_length, padding_mask=None):
batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
full_attention_mask.tril_()
if past_length:
full_attention_mask = torch.cat(
(
torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
full_attention_mask,
),
dim=-1,
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
if not past_length and padding_mask is not None:
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
def get_position_ids(batch_size, seq_length, device):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
return position_ids
class ChatGLM2InferenceForwards:
"""
    This class holds the forwards for ChatGLM2 inference.
    We intend to replace the forward methods of ChatGLMForConditionalGeneration, ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
"""
@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = True,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
        # If this is the first stage and hidden_states is not None, go through the output layer first
if stage_manager.is_first_stage() and hidden_states is not None:
if return_last_logit:
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
return {"logits": lm_logits}
outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
return outputs
@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if position_ids is None:
position_ids = get_position_ids(batch_size, seq_length, input_ids.device)
hidden_states = inputs_embeds
else:
assert hidden_states is not None, "hidden_states should not be None in non-first stage"
seq_length, batch_size, _ = hidden_states.shape
if position_ids is None:
position_ids = get_position_ids(batch_size, seq_length, hidden_states.device)
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
seq_length_with_past = seq_length + past_key_values_length
# prefill stage at first
if seq_length != 1:
infer_state.is_context_stage = True
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
# related to rotary embedding
if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = get_masks(
self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
)
# Run encoder.
hidden_states = self.encoder(
hidden_states,
full_attention_mask,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
infer_state=infer_state,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config,
)
# update indices
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
return {"hidden_states": hidden_states}
@staticmethod
def chatglm_encoder_forward(
self: GLMTransformer,
hidden_states,
attention_mask,
kv_caches=None,
use_cache: Optional[bool] = True,
output_hidden_states: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
hidden_states = hidden_states.transpose(0, 1).contiguous()
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
if kv_caches is None:
kv_caches = tuple([None] * (end_idx - start_idx + 1))
for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
layer = self.layers[idx]
layer_ret = layer(
hidden_states,
attention_mask,
kv_cache=kv_cache,
use_cache=use_cache,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states, _ = layer_ret
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
# Final layer norm.
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
@staticmethod
def chatglm_glmblock_forward(
self: GLMBlock,
hidden_states,
attention_mask,
kv_cache=None,
use_cache=True,
infer_state: Optional[BatchInferState] = None,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
attention_mask,
kv_cache=kv_cache,
use_cache=use_cache,
infer_state=infer_state,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
layernorm_input = residual + layernorm_input
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
output = residual + output
return output, kv_cache
@staticmethod
def chatglm_flash_attn_kvcache_forward(
self: SelfAttention,
hidden_states,
attention_mask,
kv_cache=None,
use_cache=True,
infer_state: Optional[BatchInferState] = None,
):
assert use_cache is True, "use_cache should be set to True using this chatglm attention"
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
batch_size = hidden_states.shape[0]
hidden_size = hidden_states.shape[-1]
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(
query_layer.size()[:-1]
+ (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
)
key_layer = key_layer.view(
key_layer.size()[:-1]
+ (
self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head,
)
)
value_layer = value_layer.view(
value_layer.size()[:-1]
+ (
self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head,
)
)
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
cos, sin = infer_state.position_cos, infer_state.position_sin
chatglm2_rotary_emb_fwd(
query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
)
if self.multi_query_attention:
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
)
else:
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
)
        # reshape q, k, v to [bsz * seq_len, num_heads, head_dim]
query_layer = query_layer.reshape(
-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
)
key_layer = key_layer.reshape(
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
)
value_layer = value_layer.reshape(
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
)
if infer_state.is_context_stage:
# first token generation:
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_layer,
value_layer,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
# NOTE: no bug in context attn fwd (del it )
lightllm_llama2_context_attention_fwd(
query_layer,
key_layer,
value_layer,
attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_layer)
cache_v.copy_(value_layer)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_layer,
value_layer,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
: infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
: infer_state.decode_mem_end, :, :
]
# ==================================
# core attention computation is replaced by triton kernel
# ==================================
Llama2TokenAttentionForwards.token_attn(
query_layer,
cache_k,
cache_v,
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
# =================
# Output:[b,sq, h]
# =================
output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
return output, kv_cache
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
import math
from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.utils import logging
from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
try:
from flash_attn import flash_attn_with_kvcache
HAS_FLASH_KERNEL = True
except:
HAS_FLASH_KERNEL = False
print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention")
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def llama_triton_context_attention(
query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
):
if num_key_value_groups == 1:
if HAS_LIGHTLLM_KERNEL is False:
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
lightllm_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
if num_key_value_groups == 1:
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
Llama2TokenAttentionForwards.token_attn(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, LlamaAttention, and LlamaForCausalLM.
"""
@staticmethod
def llama_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# The first stage also holds lm_head; if hidden_states have already been computed (passed back from the last stage), run lm_head and return the logits
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaInferenceForwards.llama_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
return outputs
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if stage_manager is None or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
assert stage_manager is not None
assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: distinguish the prefill (context) stage from the decode stage;
# block_loc requires a different value-assigning method for each of the two stages
if use_cache and seq_length != 1:
# NOTE assume prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
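# Descriptive note on the two branches above: the prefill stage reserves one KV
# slot per prompt token (total_token_num), while each decode step reserves exactly
# one new slot per sequence (batch_size), preferring a contiguous block so the
# cache update can be a plain tensor copy instead of a scatter kernel.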
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
# decoder layers
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
decoder_layer = self.layers[idx]
# NOTE: modify here for passing args to decoder layer
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]
if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
hidden_states = self.norm(hidden_states)
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
return {"hidden_states": hidden_states}
@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
@staticmethod
def llama_flash_attn_kvcache_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache must be True when using this llama attention forward"
bsz, q_len, _ = hidden_states.size()
# NOTE might think about better way to handle transposed k and v
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
cos, sin = infer_state.position_cos, infer_state.position_sin
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
if infer_state.is_context_stage:
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_triton_context_attention(
query_states,
key_states,
value_states,
attn_output,
infer_state,
num_key_value_groups=self.num_key_value_groups,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
if HAS_LIGHTLLM_KERNEL:
attn_output = torch.empty_like(query_states)
llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
)
else:
# flash_attn_with_kvcache handles grouped-query attention internally, so the kv-group size is not computed here
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
attn_output = flash_attn_with_kvcache(
q=query_states,
k_cache=copy_cache_k,
v_cache=copy_cache_v,
softmax_scale=1 / math.sqrt(self.head_dim),
causal=True,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
# return past_key_value as None
return attn_output, None, None
from .bloom import BloomModelInferPolicy
from .chatglm2 import ChatGLM2InferPolicy
from .llama import LlamaModelInferPolicy
model_policy_map = {
"llama": LlamaModelInferPolicy,
"bloom": BloomModelInferPolicy,
"chatglm": ChatGLM2InferPolicy,
}
__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy", "model_polic_map"]
from functools import partial
from typing import List
import torch
from torch.nn import LayerNorm, Module
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
from ..modeling.bloom import BloomInferenceForwards
try:
from colossalai.kernel.triton import layer_norm
HAS_TRITON_NORM = True
except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
HAS_TRITON_NORM = False
def get_triton_layernorm_forward():
if HAS_TRITON_NORM:
def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
return _triton_layernorm_forward
else:
return None
class BloomModelInferPolicy(BloomForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 3},
),
SubModuleReplacementDescription(
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
],
)
# NOTE: switch the shard config into inference mode
self.shard_config._infer()
# set as default: inference also uses the pipeline-style forward, with a single-device run treated as one pipeline stage
self.set_pipeline_forward(
model_cls=BloomForCausalLM,
new_forward=partial(
BloomInferenceForwards.bloom_for_causal_lm_forward,
tp_group=self.shard_config.tensor_parallel_process_group,
),
policy=policy,
)
method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomAttention
)
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
method_replacement = {"forward": partial(infer_method)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LayerNorm
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "BloomModel":
module = self.model
else:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
from typing import List
import torch.nn as nn
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
GLMTransformer,
SelfAttention,
)
# import colossalai
from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.chatglm2 import ChatGLM2InferenceForwards
try:
    # mirror the llama policy below: probe for the triton rmsnorm kernel
    from colossalai.kernel.triton import rmsnorm_forward  # noqa: F401

    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False
class ChatGLM2InferPolicy(ChatGLMModelPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
self.shard_config._infer()
model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
method_replacement = {"forward": model_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
method_replacement = {"forward": encoder_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=GLMTransformer
)
encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
method_replacement = {"forward": encoder_layer_infer_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
method_replacement = {"forward": attn_infer_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)
if self.shard_config.enable_tensor_parallelism:
policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
)
# for rmsnorm and others, we need to check the shape
self.set_pipeline_forward(
model_cls=ChatGLMForConditionalGeneration,
new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[nn.Module]:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
held_layers.append(module.output_layer)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
held_layers.append(module.encoder.final_layernorm)
# rotary_pos_emb is needed for all stages
held_layers.append(module.rotary_pos_emb)
return held_layers
def postprocess(self):
init_to_get_rotary(self.model.transformer)
return self.model
from functools import partial
from typing import List
import torch
from torch.nn import Module
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards
try:
from colossalai.kernel.triton import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return _triton_rmsnorm_forward
else:
return None
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
from colossalai.inference.quant.smoothquant.models.parallel_linear import (
ColW8A8BFP32OFP32Linear,
RowW8A8B8O8Linear,
RowW8A8BFP32O32LinearSiLU,
RowW8A8BFP32OFP32Linear,
)
policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=RowW8A8B8O8Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=RowW8A8BFP32O32LinearSiLU,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=RowW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=ColW8A8BFP32OFP32Linear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
# set as default: inference also uses the pipeline-style forward, with a single-device run treated as one pipeline stage
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)
return policy
def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "LlamaModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
"""
Adapted from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
We slightly changed it to make it suitable for the Colossal-AI shardformer TP-engine design.
"""
import torch
from transformers.utils import logging
class MemoryManager:
r"""
Manage token block indexes and allocate physical memory for key and value cache
Args:
size: maximum token number used as the size of key and value buffer
dtype: data type of cached key and value
head_num: number of heads the memory manager is responsible for
head_dim: embedding size per head
layer_num: the number of layers in the model
device: device used to store the key and value cache

A hedged usage sketch follows the class definition below.
"""
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device("cuda"),
):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.max_len_in_batch = 0
self._init_mem_states(size, device)
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
"""Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
"""Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
self.value_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
@torch.no_grad()
def alloc(self, required_size):
"""allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
select_index = self.indexes[select_index]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
return select_index
@torch.no_grad()
def alloc_contiguous(self, required_size):
"""allocate contiguous space of required_size"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
loc_sums = (
self.mem_cum_sum[required_size - 1 :]
- self.mem_cum_sum[0 : sum_size - required_size + 1]
+ self.mem_state[0 : sum_size - required_size + 1]
)
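# Worked example (illustrative): with mem_state = [1, 0, 1, 1, 1] (1 = free)
# and required_size = 3, mem_cum_sum = [1, 1, 2, 3, 4] and loc_sums = [2, 2, 3].
# Each entry of loc_sums is the number of free slots in the window of length
# required_size starting at that index, so only the window starting at index 2
# (slots 2, 3, 4) is fully free and can be handed out contiguously.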
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
self.logger.info(
f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
)
return None
start_loc = can_used_loc[0]
select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
end = start + required_size
return select_index, start, end
@torch.no_grad()
def free(self, free_index):
"""free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.max_len_in_batch = 0
# self.logger.info("freed all space of memory manager")
from .smoothquant.models.llama import SmoothLlamaForCausalLM
from .cai_gptq import HAS_AUTO_GPTQ
if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
from .gptq_manager import GPTQManager
import warnings

import torch
class GPTQManager:
def __init__(self, quant_config, max_input_len: int = 1):
self.max_dq_buffer_size = 1
self.max_inner_outer_dim = 1
self.bits = quant_config.bits
self.use_act_order = quant_config.desc_act
self.max_input_len = max_input_len
self.gptq_temp_state_buffer = None
self.gptq_temp_dq_buffer = None
self.quant_config = quant_config
def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
from .cai_gptq import CaiQuantLinear
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
for name, submodule in model.named_modules():
if isinstance(submodule, CaiQuantLinear):
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
if self.use_act_order:
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
max_input_len = 1
if self.use_act_order:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
torch.cuda.empty_cache()
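# Hedged usage sketch (illustrative only; `model` and `quant_config` are assumed
# to exist, and the temp buffers are only allocated for 4-bit kernels):
#
#   gptq_manager = GPTQManager(quant_config, max_input_len=2048)
#   gptq_manager.post_init_gptq_buffer(model)  # sizes and registers the temp state/dq buffers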