Unverified Commit c1a3ab29 authored by Haojie Wang's avatar Haojie Wang Committed by GitHub
Browse files

Merge pull request #173 from InfiniTensor/issue/168

Issue/168 InfiniLM接入paged attention接口
parents 96e53dbb 09ab8fa4
......@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
return ctypes.c_int32
elif infini_dtype == infinicore.float32:
return ctypes.c_float
elif infini_dtype == infinicore.int64:
return ctypes.c_int64
else:
raise ValueError(f"Unsupported py_dtype: {infini_dtype}")
......
......@@ -4,7 +4,7 @@ from dataclasses import dataclass
import infinicore
from infinilm.auto_config import AutoConfig
from infinilm.cache import StaticKVCacheConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.distributed import DistConfig
from infinilm.lib import _infinilm
......@@ -18,6 +18,7 @@ class GenerationConfig:
top_p: float = 1.0
eos_token_id: list[int] | None = None
stop_on_eos: bool = True
class InferEngine(_infinilm.InferEngine):
......@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
self.use_cache = False
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
......@@ -50,8 +53,8 @@ class InferEngine(_infinilm.InferEngine):
input_ids,
*,
position_ids=None,
cache_lengths=None,
input_lengths=None,
past_kv_lengths=None,
total_kv_lengths=None,
input_offsets=None,
block_tables=None,
slot_mapping=None,
......@@ -62,8 +65,12 @@ class InferEngine(_infinilm.InferEngine):
# TODO: Remove `_underlying` and simplify the corresponding code.
input_ids = input_ids._underlying if input_ids is not None else None
position_ids = position_ids._underlying if position_ids is not None else None
cache_lengths = cache_lengths._underlying if cache_lengths is not None else None
input_lengths = input_lengths._underlying if input_lengths is not None else None
past_kv_lengths = (
past_kv_lengths._underlying if past_kv_lengths is not None else None
)
total_kv_lengths = (
total_kv_lengths._underlying if total_kv_lengths is not None else None
)
input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
......@@ -74,8 +81,8 @@ class InferEngine(_infinilm.InferEngine):
super().Input(
input_ids,
position_ids=position_ids,
cache_lengths=cache_lengths,
input_lengths=input_lengths,
past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets,
block_tables=block_tables,
slot_mapping=slot_mapping,
......@@ -87,21 +94,24 @@ class InferEngine(_infinilm.InferEngine):
.output_ids
)
def generate(self, input_ids, generation_config, *, _measure_and_log_time=False):
def generate(
self,
input_ids,
generation_config,
*,
_measure_and_log_time=False,
paged_block_size=16,
):
if generation_config.eos_token_id is None:
eos_token_id = self.config.eos_token_id
else:
eos_token_id = generation_config.eos_token_id
# TODO: Remove the `to_numpy` calls and simplify the corresponding code.
batch_size, seq_len = input_ids.shape[:2]
position_ids = infinicore.from_list(
[list(range(0, seq_len)) for _ in range(batch_size)], dtype=infinicore.int64
)
cache_lengths = infinicore.from_list([0], dtype=infinicore.int64)
past_seq_len = 0
output_ids = []
initial_batch_size, initial_seqlen = input_ids.shape[:2]
seq_len = initial_seqlen
batch_size = initial_batch_size
if batch_size != 1 and generation_config.max_new_tokens is None:
raise ValueError(
......@@ -111,14 +121,75 @@ class InferEngine(_infinilm.InferEngine):
if _measure_and_log_time:
time_measurements = []
for _ in range(0, generation_config.max_new_tokens):
for iter in range(0, generation_config.max_new_tokens):
if _measure_and_log_time:
start_time = time.perf_counter()
batch_size, seq_len = input_ids.shape[:2]
if self.enable_paged_attn:
input_ids = input_ids.view([1, batch_size * seq_len])
position_ids = infinicore.from_list(
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64,
)
block_tables_list = [
[
i * batch_size + b
for i in range(
(past_seq_len + seq_len + paged_block_size - 1)
// paged_block_size
)
]
for b in range(batch_size)
]
slot_mapping_list = [
(((past_seq_len + i) // paged_block_size) * batch_size + b)
* paged_block_size
+ (past_seq_len + i) % paged_block_size
for b in range(batch_size)
for i in range(seq_len)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
slot_mapping = infinicore.from_list(
slot_mapping_list,
dtype=infinicore.int64,
)
else:
position_ids = infinicore.from_list(
[
list(range(past_seq_len, past_seq_len + seq_len))
for _ in range(batch_size)
],
dtype=infinicore.int64,
)
block_tables = None
slot_mapping = None
past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64
)
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
)
output_id = self(
input_ids,
input_ids=input_ids,
position_ids=position_ids,
cache_lengths=cache_lengths,
past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=generation_config.temperature,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
......@@ -127,24 +198,17 @@ class InferEngine(_infinilm.InferEngine):
output_ids.append(output_id)
if (
generation_config.max_new_tokens is not None
initial_batch_size == 1
and generation_config.stop_on_eos
and generation_config.max_new_tokens is not None
and output_id.to_numpy()[0] in eos_token_id
):
break
seq_len = position_ids.shape[-1]
input_ids = infinicore.from_list(
[[output_id] for output_id in output_id.to_numpy().tolist()]
)
position_ids = infinicore.from_list(
[1 for _ in range(batch_size)],
dtype=position_ids.dtype,
device=position_ids.device,
).view((batch_size, 1)) + position_ids.narrow(1, seq_len - 1, 1)
cache_lengths += infinicore.from_list(
[seq_len], dtype=cache_lengths.dtype, device=cache_lengths.device
)
past_seq_len = past_seq_len + seq_len
if _measure_and_log_time:
end_time = time.perf_counter()
......@@ -156,23 +220,21 @@ class InferEngine(_infinilm.InferEngine):
f"\n\n\n Generation completed in {round(sum(time_measurements) * 1000, 2)} ms"
)
print(
f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_measurements)}\n"
f" Batchsize={initial_batch_size} Per_Batch_Input_Len={initial_seqlen} Per_Batch_New_Tokens={len(time_measurements)}\n"
)
print(
f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_measurements[0], 2)}tok/s\n",
f" Prefill TTFT: {round(time_measurements[0], 2)}ms Throughput: {round((initial_batch_size * initial_seqlen) / time_measurements[0], 2)}tok/s\n",
)
if len(time_measurements) > 1:
print(
f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
f" Decode Avg ITL: {round(sum(time_measurements[1:]) * 1000 / (len(time_measurements) - 1), 2)}ms Throughput: {round((initial_batch_size * (len(time_measurements) - 1)) / sum(time_measurements[1:]), 2)}tok/s\n",
)
return output_ids
def reset_cache(self, batch_size: int, initial_capacity: int = 1024):
def reset_cache(self, cache_config):
infinicore.sync_device()
cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
super().reset_cache(cache_config)
def state_dict_keyname(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment