Unverified Commit fd6482ad authored by Xu Kai, committed by GitHub

[inference] Refactor inference architecture (#5057)



* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent bc09b95f
......@@ -7,7 +7,7 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
......@@ -93,9 +93,7 @@ class GenerateSchedule(PipelineSchedule):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
model_inputs = {
'infer_state': self.mb_manager.cur_descrption.infer_state
}
model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
......@@ -129,8 +127,8 @@ class GenerateSchedule(PipelineSchedule):
def _init_infer_state_action(self) -> None:
"""
This action is only for no first stage, to load batch and init infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
This action is only for no first stage, to load batch and init infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
......@@ -145,19 +143,19 @@ class GenerateSchedule(PipelineSchedule):
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
This action is only for first stage
This action is only for first stage
1.do the forward with hidden_states to generate new tokens 2.step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
......@@ -178,18 +176,18 @@ class GenerateSchedule(PipelineSchedule):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
......@@ -233,12 +231,73 @@ class GenerateSchedule(PipelineSchedule):
return actions
def _gen_one_stage_action(self, model: Module):
"""
Build the list of actions for the current microbatch state; the caller executes them one by one.
Args:
model (Module): Model to be run.
Returns:
List[Callable]: A list of actions; each is a callable that will be invoked in order.
"""
actions = []
if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._load_stage_action, model))
elif self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self._gen_token_action, model))
return actions
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
if self.stage_manager.num_stages == 2:
if self.stage_manager.num_stages == 1:
return self.generate_step_one_stage(model, data_iter)
elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
@torch.no_grad()
def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
Forward one step of the pipeline, when pipeline size is 1.
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
Returns:
The generated output sequence for the batch, together with the recorded timestamps when `verbose` is enabled.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
self.comm_dtype = model.dtype
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
self.action_interval_buffer.clear()
while self.mb_manager.is_micro_batch_done() is False:
actions = self._gen_one_stage_action(model)
for action in actions:
action()
self.mb_manager.next()
# All microbatches in the current round are DONE
output_sequence.extend(self.mb_manager.export_new_tokens())
self.mb_manager.clear()
if self.verbose:
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
......@@ -319,7 +378,7 @@ class GenerateSchedule(PipelineSchedule):
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
......@@ -330,18 +389,23 @@ class GenerateSchedule(PipelineSchedule):
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
assert (
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
......@@ -350,7 +414,10 @@ class GenerateSchedule(PipelineSchedule):
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# Current microbatch is not DONE, send hidden_state to next stage
......
......@@ -81,6 +81,8 @@ Following are the description `ShardConfig`'s arguments:
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
- `extra_kwargs`: A dict to store extra kwargs for Shardformer (a short construction sketch follows below).
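As a rough, non-authoritative sketch: the benchmark script added later in this commit builds an inference-oriented `ShardConfig` as shown below. Only the `"inference_only"` key is visible in this diff, so any other `extra_kwargs` entries (for example, ones selecting a quantization type) should be treated as assumptions.

```python
from colossalai.shardformer import ShardConfig

# Minimal sketch based on the benchmark added in this commit; "inference_only"
# is the only key shown in the diff, any further keys would be assumptions.
shard_config = ShardConfig(
    enable_tensor_parallelism=True,          # shard weights across tensor-parallel ranks
    extra_kwargs={"inference_only": True},   # route policy search to the inference policies
)
```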
### Write your own policy
If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
......@@ -184,6 +186,7 @@ class ShardConfig:
# Some possible future config fields
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
use_flash_attention: bool # whether to use flash attention to speed up attention
extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
```
### Policy
......
......@@ -209,7 +209,7 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
inference_only = shard_config.extra_kwargs.get("inference_only", False)
inference_only = shard_config.extra_kwargs.get("inference_only", None)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
......
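The hunk above is where the "policy search and quant type controlling" mentioned in the commit message is routed: an `inference_only` entry in `ShardConfig.extra_kwargs` switches the lookup from the training policy table to the inference one. A self-contained sketch of that pattern follows; the registry contents are placeholders, not the real `_POLICY_LIST` / `_INFER_POLICY_LIST` entries, which store policy-location objects rather than strings.

```python
from typing import Any, Dict, Optional

# Placeholder registries; the real tables map full model class names to policy
# locations and are not reproduced here.
_POLICY_LIST: Dict[str, str] = {
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": "llama.LlamaPolicy",
}
_INFER_POLICY_LIST: Dict[str, str] = {
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": "llama.LlamaInferPolicy",
}


def pick_policy_location(full_name: str, extra_kwargs: Dict[str, Any]) -> Optional[str]:
    # Mirrors the hunk above: an "inference_only" entry in ShardConfig.extra_kwargs
    # switches the lookup to the inference policy table.
    if extra_kwargs.get("inference_only", None):
        return _INFER_POLICY_LIST.get(full_name, None)
    return _POLICY_LIST.get(full_name, None)
```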
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Any, Dict, Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
......@@ -34,7 +34,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
......
import argparse
import os
import time
import torch
from _utils import print_perf_stats
from transformers import BloomForCausalLM, BloomTokenizerFast
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def bench_bloom(args):
model_path = args.path
max_batch_size = args.batch_size
max_input_len = args.input_len
max_output_len = args.output_len
tokenizer = BloomTokenizerFast.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
# prepare data for generation
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
"input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
"attention_mask": torch.ones((max_batch_size, max_input_len)),
}
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
iters = 10
times = []
for i in range(iters):
torch.cuda.synchronize()
start = time.time()
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
end = time.time()
out_len = outputs.shape[1]
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
times.append((end - start) / (out_len - max_input_len))
print_perf_stats(times, model.config, max_batch_size)
def check_bloom(rank, world_size, port, args):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
bench_bloom(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom(args):
spawn(check_bloom, args.tp_size, args=args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
args = parser.parse_args()
test_bloom(args)
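For reference, a typical invocation of the benchmark above (the script's file name is not visible in this view and the checkpoint path is a placeholder) might be `python benchmark_bloom.py -p /path/to/bloom -tp 2 -b 16 --input_len 1024 --output_len 128`; all flags come from the `argparse` definitions in the script, and `spawn` launches one process per tensor-parallel rank, so `-tp 2` needs two visible GPUs.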
transformers==4.34.0
packaging
ninja
auto-gptq==0.5.0
git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
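These pinned packages back the new inference examples ("add requirements-infer" in the commit message). Assuming the file is saved as `requirements-infer.txt` (its exact name and location are not shown here), it can be installed with `pip install -r requirements-infer.txt`; the git-based entries (lightllm, xformers, flash-attention) are built from source and typically need a CUDA toolchain available.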