Unverified commit c1a3ab29, authored by Haojie Wang, committed by GitHub
Browse files

Merge pull request #173 from InfiniTensor/issue/168

Issue/168 InfiniLM接入paged attention接口
parents 96e53dbb 09ab8fa4
...@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype): ...@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
return ctypes.c_int32 return ctypes.c_int32
elif infini_dtype == infinicore.float32: elif infini_dtype == infinicore.float32:
return ctypes.c_float return ctypes.c_float
elif infini_dtype == infinicore.int64:
return ctypes.c_int64
else: else:
raise ValueError(f"Unsupported py_dtype: {infini_dtype}") raise ValueError(f"Unsupported py_dtype: {infini_dtype}")
......
...@@ -4,7 +4,7 @@ from dataclasses import dataclass ...@@ -4,7 +4,7 @@ from dataclasses import dataclass
import infinicore import infinicore
from infinilm.auto_config import AutoConfig from infinilm.auto_config import AutoConfig
from infinilm.cache import StaticKVCacheConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.distributed import DistConfig from infinilm.distributed import DistConfig
from infinilm.lib import _infinilm from infinilm.lib import _infinilm
...@@ -18,6 +18,7 @@ class GenerationConfig: ...@@ -18,6 +18,7 @@ class GenerationConfig:
top_p: float = 1.0 top_p: float = 1.0
eos_token_id: list[int] | None = None eos_token_id: list[int] | None = None
stop_on_eos: bool = True
class InferEngine(_infinilm.InferEngine): class InferEngine(_infinilm.InferEngine):
...@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine): ...@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
self.use_cache = False self.use_cache = False
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
...@@ -50,8 +53,8 @@ class InferEngine(_infinilm.InferEngine): ...@@ -50,8 +53,8 @@ class InferEngine(_infinilm.InferEngine):
input_ids, input_ids,
*, *,
position_ids=None, position_ids=None,
cache_lengths=None, past_kv_lengths=None,
input_lengths=None, total_kv_lengths=None,
input_offsets=None, input_offsets=None,
block_tables=None, block_tables=None,
slot_mapping=None, slot_mapping=None,
...@@ -62,8 +65,12 @@ class InferEngine(_infinilm.InferEngine): ...@@ -62,8 +65,12 @@ class InferEngine(_infinilm.InferEngine):
# TODO: Remove `_underlying` and simplify the corresponding code. # TODO: Remove `_underlying` and simplify the corresponding code.
input_ids = input_ids._underlying if input_ids is not None else None input_ids = input_ids._underlying if input_ids is not None else None
position_ids = position_ids._underlying if position_ids is not None else None position_ids = position_ids._underlying if position_ids is not None else None
cache_lengths = cache_lengths._underlying if cache_lengths is not None else None past_kv_lengths = (
input_lengths = input_lengths._underlying if input_lengths is not None else None past_kv_lengths._underlying if past_kv_lengths is not None else None
)
# Unwrap total_kv_lengths only when it was actually supplied. The guard has to
# test total_kv_lengths itself: testing past_kv_lengths raised AttributeError
# (None has no `_underlying`) when only past_kv_lengths was passed, and
# silently replaced a supplied total_kv_lengths with None when it was passed
# without past_kv_lengths.
total_kv_lengths = (
    total_kv_lengths._underlying if total_kv_lengths is not None else None
)
input_offsets = input_offsets._underlying if input_offsets is not None else None input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None block_tables = block_tables._underlying if block_tables is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
...@@ -74,8 +81,8 @@ class InferEngine(_infinilm.InferEngine): ...@@ -74,8 +81,8 @@ class InferEngine(_infinilm.InferEngine):
super().Input( super().Input(
input_ids, input_ids,
position_ids=position_ids, position_ids=position_ids,
cache_lengths=cache_lengths, past_sequence_lengths=past_kv_lengths,
input_lengths=input_lengths, total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
...@@ -87,21 +94,24 @@ class InferEngine(_infinilm.InferEngine): ...@@ -87,21 +94,24 @@ class InferEngine(_infinilm.InferEngine):
.output_ids .output_ids
) )
def generate(self, input_ids, generation_config, *, _measure_and_log_time=False): def generate(
self,
input_ids,
generation_config,
*,
_measure_and_log_time=False,
paged_block_size=16,
):
if generation_config.eos_token_id is None: if generation_config.eos_token_id is None:
eos_token_id = self.config.eos_token_id eos_token_id = self.config.eos_token_id
else: else:
eos_token_id = generation_config.eos_token_id eos_token_id = generation_config.eos_token_id
# TODO: Remove the `to_numpy` calls and simplify the corresponding code. past_seq_len = 0
batch_size, seq_len = input_ids.shape[:2]
position_ids = infinicore.from_list(
[list(range(0, seq_len)) for _ in range(batch_size)], dtype=infinicore.int64
)
cache_lengths = infinicore.from_list([0], dtype=infinicore.int64)
output_ids = [] output_ids = []
initial_batch_size, initial_seqlen = input_ids.shape[:2]
seq_len = initial_seqlen
batch_size = initial_batch_size
if batch_size != 1 and generation_config.max_new_tokens is None: if batch_size != 1 and generation_config.max_new_tokens is None:
raise ValueError( raise ValueError(
...@@ -111,14 +121,75 @@ class InferEngine(_infinilm.InferEngine): ...@@ -111,14 +121,75 @@ class InferEngine(_infinilm.InferEngine):
if _measure_and_log_time: if _measure_and_log_time:
time_measurements = [] time_measurements = []
for _ in range(0, generation_config.max_new_tokens): for iter in range(0, generation_config.max_new_tokens):
if _measure_and_log_time: if _measure_and_log_time:
start_time = time.perf_counter() start_time = time.perf_counter()
batch_size, seq_len = input_ids.shape[:2]
if self.enable_paged_attn:
input_ids = input_ids.view([1, batch_size * seq_len])
position_ids = infinicore.from_list(
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64,
)
block_tables_list = [
[
i * batch_size + b
for i in range(
(past_seq_len + seq_len + paged_block_size - 1)
// paged_block_size
)
]
for b in range(batch_size)
]
slot_mapping_list = [
(((past_seq_len + i) // paged_block_size) * batch_size + b)
* paged_block_size
+ (past_seq_len + i) % paged_block_size
for b in range(batch_size)
for i in range(seq_len)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
slot_mapping = infinicore.from_list(
slot_mapping_list,
dtype=infinicore.int64,
)
else:
position_ids = infinicore.from_list(
[
list(range(past_seq_len, past_seq_len + seq_len))
for _ in range(batch_size)
],
dtype=infinicore.int64,
)
block_tables = None
slot_mapping = None
past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64
)
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
)
output_id = self( output_id = self(
input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cache_lengths=cache_lengths, past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=generation_config.temperature, temperature=generation_config.temperature,
top_k=generation_config.top_k, top_k=generation_config.top_k,
top_p=generation_config.top_p, top_p=generation_config.top_p,
...@@ -127,24 +198,17 @@ class InferEngine(_infinilm.InferEngine): ...@@ -127,24 +198,17 @@ class InferEngine(_infinilm.InferEngine):
output_ids.append(output_id) output_ids.append(output_id)
if ( if (
generation_config.max_new_tokens is not None initial_batch_size == 1
and generation_config.stop_on_eos
and generation_config.max_new_tokens is not None
and output_id.to_numpy()[0] in eos_token_id and output_id.to_numpy()[0] in eos_token_id
): ):
break break
seq_len = position_ids.shape[-1]
input_ids = infinicore.from_list( input_ids = infinicore.from_list(
[[output_id] for output_id in output_id.to_numpy().tolist()] [[output_id] for output_id in output_id.to_numpy().tolist()]
) )
position_ids = infinicore.from_list( past_seq_len = past_seq_len + seq_len
[1 for _ in range(batch_size)],
dtype=position_ids.dtype,
device=position_ids.device,
).view((batch_size, 1)) + position_ids.narrow(1, seq_len - 1, 1)
cache_lengths += infinicore.from_list(
[seq_len], dtype=cache_lengths.dtype, device=cache_lengths.device
)
if _measure_and_log_time: if _measure_and_log_time:
end_time = time.perf_counter() end_time = time.perf_counter()
...@@ -156,23 +220,21 @@ class InferEngine(_infinilm.InferEngine): ...@@ -156,23 +220,21 @@ class InferEngine(_infinilm.InferEngine):
f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms" f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms"
) )
print( print(
f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_measurements)}\n" f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n"
) )
print( print(
f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_measurements[0], 2)}tok/s\n", f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n",
) )
if len(time_measurements) > 1: if len(time_measurements) > 1:
print( print(
f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n", f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
) )
return output_ids return output_ids
def reset_cache(self, batch_size: int, initial_capacity: int = 1024): def reset_cache(self, cache_config):
infinicore.sync_device() infinicore.sync_device()
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
super().reset_cache(cache_config) super().reset_cache(cache_config)
def state_dict_keyname(self): def state_dict_keyname(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment