Unverified Commit fd6482ad authored by Xu Kai, committed by GitHub

[inference] Refactor inference architecture (#5057)



* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent bc09b95f
@@ -4,9 +4,7 @@ try:
     HAS_TORCH_INT = True
 except ImportError:
     HAS_TORCH_INT = False
-    raise ImportError(
-        "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
-    )
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")

 if HAS_TORCH_INT:
     from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
@@ -9,7 +9,6 @@ from functools import partial
 from os.path import isdir, isfile, join
 from typing import Dict, List, Optional, Union

-import accelerate
 import numpy as np
 import torch
 import torch.nn as nn
@@ -21,8 +20,16 @@ from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers
 from transformers.utils.hub import PushToHubMixin, cached_file

-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState, MemoryManager
+
+try:
+    import accelerate
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    print("accelerate is not installed.")

 SUPPORTED_MODELS = ["llama"]
 # modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
 import torch
-from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
-from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+    from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+    from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")

 try:
     from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

     smoothquant_cuda = SmoothquantBuilder().load()
     HAS_SMOOTHQUANT_CUDA = True
-except ImportError:
+except:
     HAS_SMOOTHQUANT_CUDA = False
-    raise ImportError("CUDA smoothquant linear is not installed")
+    print("CUDA smoothquant linear is not installed")


 class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ class W8A8BFP32OFP32Linear(torch.nn.Module):
         )
         self.register_buffer(
             "bias",
-            torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+            torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
         )
         self.register_buffer("a", torch.tensor(alpha))

     def _apply(self, fn):
         # prevent the bias from being converted to half
         super()._apply(fn)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(torch.float32)
         return self

     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
         self.weight = self.weight.to(*args, **kwargs)
-        self.bias = self.bias.to(*args, **kwargs)
-        self.bias = self.bias.to(torch.float32)
+        if self.bias is not None:
+            self.bias = self.bias.to(*args, **kwargs)
+            self.bias = self.bias.to(torch.float32)
         return self

     @torch.no_grad()
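The guarded `_apply`/`to` overrides above exist to keep the fp32 bias buffer from being downcast when the module is converted with `.half()`. Below is a minimal, self-contained sketch of that pattern using only plain `torch`; the class name is illustrative and not part of the library.

```python
# Minimal sketch (not the library's class): keep one buffer in fp32 even when the
# module is converted with .half(), mirroring the _apply override in the diff above.
import torch


class KeepFP32Bias(torch.nn.Module):
    def __init__(self, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.randn(out_features, out_features))
        self.register_buffer("bias", torch.zeros((1, out_features), dtype=torch.float32))

    def _apply(self, fn):
        # let the default conversion run, then force the bias back to fp32
        super()._apply(fn)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)
        return self


m = KeepFP32Bias(4).half()
print(m.weight.dtype, m.bias.dtype)  # torch.float16 torch.float32
```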
@@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,12 +17,11 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaMLP,
     LlamaRotaryEmbedding,
-    repeat_kv,
     rotate_half,
 )
 from transformers.utils import add_start_docstrings_to_model_forward

-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
 from colossalai.kernel.triton import (
     copy_kv_cache_to_dest,
     int8_rotary_embedding_fwd,
@@ -31,10 +29,31 @@ from colossalai.kernel.triton import (
     smooth_token_attention_fwd,
 )

+try:
+    from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
 from .base_model import BaseSmoothForCausalLM
 from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class LLamaSmoothquantAttention(nn.Module):
     def __init__(
         self,
@@ -116,7 +135,6 @@ class LLamaSmoothquantAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ class LLamaSmoothquantAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        cos = rotary_emb[0]
-        sin = rotary_emb[1]
+        cos, sin = infer_state.position_cos, infer_state.position_sin

         int8_rotary_embedding_fwd(
             query_states.view(-1, self.num_heads, self.head_dim),
@@ -348,7 +365,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        rotary_emb: Tuple[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -378,7 +394,6 @@ class LlamaSmoothquantDecoderLayer(nn.Module):
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
-            rotary_emb=rotary_emb,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
@@ -650,15 +665,15 @@ def llama_model_forward(
         raise NotImplementedError("not implement gradient_checkpointing and training options ")

     if past_key_values_length == 0:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
             position_ids.view(-1).shape[0], -1
         )
     else:
-        position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
-        position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)

     # decoder layers
     all_hidden_states = () if output_hidden_states else None
@@ -673,7 +688,6 @@ def llama_model_forward(
         layer_outputs = decoder_layer(
             hidden_states,
-            rotary_emb=(position_cos, position_sin),
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
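The diff above inlines a local `repeat_kv` because it is no longer imported from `transformers`. As its docstring states, it matches `torch.repeat_interleave` along the key/value-head dimension; the small self-contained check below (plain `torch` only) illustrates that equivalence.

```python
# Quick self-contained check: the repeat_kv added in the diff above is equivalent to
# torch.repeat_interleave along the key/value-head dimension (dim=1).
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


x = torch.randn(2, 4, 8, 16)  # (batch, num_kv_heads, seqlen, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
```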
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
    # copy the tp_rank-th shard of the source linear's weight (and bias) along the
    # output dimension into the parallel linear; split_num handles fused weights
    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
    if smooth_linear.bias is not None:
        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)

    smooth_split_out_features = para_linear.out_features // split_num
    for i in range(split_num):
        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
        ]
        if para_linear.bias is not None:
            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
            ]


def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
    # copy the tp_rank-th shard of the source linear's weight along the input
    # dimension into the parallel linear; the bias is copied whole
    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)

    smooth_split_in_features = para_linear.in_features // split_num
    for i in range(split_num):
        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
        ]
    if smooth_linear.bias is not None:
        para_linear.bias.copy_(smooth_linear.bias)


class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d


class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = torch.tensor(module.a)
        linear_1d.b = torch.tensor(module.b)

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # every rank keeps the full bias, so scale it down to avoid double counting
            # after the all-reduce in forward
            linear_1d.bias = linear_1d.bias // tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            # partial results from the input-split ranks are summed across the group
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output


class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d


class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)

        return linear_1d


class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
            )
        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # every rank keeps the full bias, so scale it down to avoid double counting
            # after the all-reduce in forward
            linear_1d.bias = linear_1d.bias / tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            # partial results from the input-split ranks are summed across the group
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output
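To make the partitioning arithmetic in `split_row_copy` / `split_column_copy` concrete, here is a standalone sketch with plain float tensors (illustrative only, shown for `split_num=1`; the real helpers operate on the int8 weight buffers of the quantized linears):

```python
# Standalone sketch of the partitioning performed by split_row_copy / split_column_copy.
import torch

out_features, in_features, tp_size = 8, 6, 2
weight = torch.arange(out_features * in_features, dtype=torch.float32).reshape(out_features, in_features)

# "Row" split: each rank keeps a contiguous slice of the output dimension.
row_shards = [
    weight[r * (out_features // tp_size) : (r + 1) * (out_features // tp_size), :] for r in range(tp_size)
]
assert torch.equal(torch.cat(row_shards, dim=0), weight)

# "Column" split: each rank keeps a slice of the input dimension; the partial outputs
# are then summed with the all-reduce in the Col* forward methods above.
col_shards = [
    weight[:, r * (in_features // tp_size) : (r + 1) * (in_features // tp_size)] for r in range(tp_size)
]
assert torch.equal(torch.cat(col_shards, dim=1), weight)
```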
# 🚀 Colossal-Inference
## Introduction
`Colossal Inference` is a module that contains the Colossal-AI-designed inference framework, featuring high performance, stability, and ease of use. `Colossal Inference` incorporates the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer, and FlashAttention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
## Design
Colossal Inference is composed of three main components:
1. High-performance kernels and ops: inspired by existing libraries and modified accordingly.
2. Efficient memory management mechanism: includes the key-value cache manager, allowing for zero memory waste during inference.
   1. `cache manager`: serves as a memory manager for the key-value cache; it integrates functions such as memory allocation, indexing, and release.
   2. `batch_infer_state`: holds all essential elements of a batch inference and is updated every batch.
3. High-level inference engine combined with `Shardformer`: allows our inference framework to easily invoke and utilize various parallel methods.
   1. `engine.TPInferEngine`: a high-level interface that integrates with Shardformer, especially for multi-card (tensor parallel) inference (see the usage sketch after this list).
   2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for Llama inference.
   3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, which are used to call `shardformer` and partition the model forward pass for tensor parallelism.
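A hedged usage sketch of the tensor-parallel path described above; the import paths, `ShardConfig` flags, and `generate` arguments below are assumptions based on this README's description rather than a verified API reference.

```python
# Hedged sketch only: engine/config names follow the README description above;
# exact import paths and signatures may differ in the codebase.
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import TPInferEngine  # assumed import path for engine.TPInferEngine
from colossalai.shardformer import ShardConfig

colossalai.launch_from_torch(config={})

tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama")
model = LlamaForCausalLM.from_pretrained("/path/to/llama")

# flags below are assumed; they mirror "inference with Shardformer tensor parallelism"
shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
engine = TPInferEngine(model, shard_config, max_batch_size=8, max_input_len=1024, max_output_len=128)

inputs = tokenizer("Introduce Beijing in one sentence.", return_tensors="pt")
outputs = engine.generate(inputs, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```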
## Pipeline of inference
In this section, we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.
![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png)
## Roadmap of our implementation
- [x] Design cache manager and batch infer state
- [x] Design the TPInference engine to integrate with `Shardformer`
- [x] Register corresponding high-performance `kernel` and `ops`
- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
- [x] policy
- [x] context forward
- [x] token forward
- [x] support flash-decoding
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Llama-2
- [x] Bloom
- [x] Chatglm2
- [ ] Benchmarking for all models
## Get started
### Installation
```bash
pip install -e .
```
### Requirements
Dependencies:
```bash
pytorch == 1.13.1 (GPU)
cuda >= 11.6
transformers == 4.30.2
triton

# install flash-attention
flash-attention
# install lightllm since we depend on lightllm triton kernels
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
# also, install xformers from source:
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
### Docker
You can use `docker run` to set up the environment in a Docker container.
```bash
# env: python == 3.8, cuda 11.6, pytorch == 1.13.1, triton == 2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
# enter into docker container
cd /path/to/ColossalAI
pip install -e .
# install lightllm
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
# install xformers from source
pip install ninja
# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
```
### Dive into fast-inference!
Example files can be found under the examples folder:
```bash
cd colossalai.examples
python xx
```
## Performance
### Environment
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `colossal-inference` and the original `hugging-face torch fp16`.
For various models, experiments were conducted with multiple batch sizes under a consistent model configuration: `7 billion (7b)` parameters, `1024` input length, and `128` output length. The results are as follows (due to time constraints, the evaluation has so far been performed only on a single `A100` GPU; multi-GPU performance will be addressed in the future):
### Single GPU Performance
Currently, the stats below are measured on a single A100 GPU. Token latency is computed from the average over the context-forward (prefill) and decoding-forward passes, i.e., both stages are combined when calculating token generation time, as sketched below. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
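As a concrete reading of that latency definition, the sketch below combines one context (prefill) forward with the per-token decoding forwards; the timings are made up for illustration only, not benchmark data.

```python
# Illustrative arithmetic only (made-up timings, not measurements): average token
# latency and throughput when combining the context forward with decoding forwards.
context_forward_s = 0.35            # one forward pass over the 1024-token prompt
decode_forwards_s = [0.018] * 128   # one forward pass per generated token (128 outputs)

total_time_s = context_forward_s + sum(decode_forwards_s)
tokens_generated = len(decode_forwards_s)

avg_token_latency_ms = total_time_s / tokens_generated * 1000
throughput_tok_per_s = tokens_generated / total_time_s
print(f"{avg_token_latency_ms:.1f} ms/token, {throughput_tok_per_s:.1f} tokens/s")
```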
#### Llama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
| colossal-inference | 326.4 | 582.72 | 816.64 |
![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png)
#### Bloom
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
| colossal-inference | 323.28 | 538.52 | 611.64 |
![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png)
The results of more models are coming soon!
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
from .engine import CaiInferEngine
__all__ = ["CaiInferEngine"]