Commit 6cc680ba, authored by PanZezhong and committed by wooway777
Browse files

issue/991 optimize input preparation

parent 144ba492
...@@ -123,6 +123,22 @@ class InferEngine(_infinilm.InferEngine): ...@@ -123,6 +123,22 @@ class InferEngine(_infinilm.InferEngine):
if _measure_and_log_time: if _measure_and_log_time:
time_measurements = [] time_measurements = []
block_tables = None
max_blocks_per_batch = 0
if self.enable_paged_attn:
max_blocks_per_batch = (
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
) // paged_block_size
block_tables_list = [
range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)
for i in range(batch_size)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
for iter in range(0, generation_config.max_new_tokens): for iter in range(0, generation_config.max_new_tokens):
if _measure_and_log_time: if _measure_and_log_time:
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -135,28 +151,28 @@ class InferEngine(_infinilm.InferEngine): ...@@ -135,28 +151,28 @@ class InferEngine(_infinilm.InferEngine):
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size, list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64, dtype=infinicore.int64,
) )
block_tables_list = [
if iter == 0:
slot_mapping_list = []
for b in range(batch_size):
slot_mapping_list.extend(
[ [
i * batch_size + b b * max_blocks_per_batch * paged_block_size + i
for i in range( for i in range(seq_len)
(past_seq_len + seq_len + paged_block_size - 1)
// paged_block_size
)
]
for b in range(batch_size)
] ]
)
else:
slot_mapping_list = [ slot_mapping_list = [
(((past_seq_len + i) // paged_block_size) * batch_size + b) i
for i in range(
past_seq_len,
max_blocks_per_batch
* paged_block_size * paged_block_size
+ (past_seq_len + i) % paged_block_size * initial_batch_size,
for b in range(batch_size) max_blocks_per_batch * paged_block_size,
for i in range(seq_len) )
] ]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
slot_mapping = infinicore.from_list( slot_mapping = infinicore.from_list(
slot_mapping_list, slot_mapping_list,
dtype=infinicore.int64, dtype=infinicore.int64,
...@@ -170,7 +186,6 @@ class InferEngine(_infinilm.InferEngine): ...@@ -170,7 +186,6 @@ class InferEngine(_infinilm.InferEngine):
dtype=infinicore.int64, dtype=infinicore.int64,
) )
block_tables = None
slot_mapping = None slot_mapping = None
past_kv_lengths = infinicore.from_list( past_kv_lengths = infinicore.from_list(
...@@ -207,9 +222,9 @@ class InferEngine(_infinilm.InferEngine): ...@@ -207,9 +222,9 @@ class InferEngine(_infinilm.InferEngine):
): ):
break break
input_ids = infinicore.from_list( # start_prepare_time = time.perf_counter()
[[output_id] for output_id in output_id.to_numpy().tolist()] input_ids = output_id.view([batch_size, 1])
)
past_seq_len = past_seq_len + seq_len past_seq_len = past_seq_len + seq_len
if _measure_and_log_time: if _measure_and_log_time:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment