Commit 6cc680ba authored by PanZezhong, committed by wooway777
Browse files

issue/991 optimize input preparation

parent 144ba492
......@@ -123,6 +123,22 @@ class InferEngine(_infinilm.InferEngine):
if _measure_and_log_time:
time_measurements = []
block_tables = None
max_blocks_per_batch = 0
if self.enable_paged_attn:
max_blocks_per_batch = (
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
) // paged_block_size
block_tables_list = [
range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)
for i in range(batch_size)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
for iter in range(0, generation_config.max_new_tokens):
if _measure_and_log_time:
start_time = time.perf_counter()
......@@ -135,28 +151,28 @@ class InferEngine(_infinilm.InferEngine):
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64,
)
block_tables_list = [
[
i * batch_size + b
if iter == 0:
slot_mapping_list = []
for b in range(batch_size):
slot_mapping_list.extend(
[
b * max_blocks_per_batch * paged_block_size + i
for i in range(seq_len)
]
)
else:
slot_mapping_list = [
i
for i in range(
(past_seq_len + seq_len + paged_block_size - 1)
// paged_block_size
past_seq_len,
max_blocks_per_batch
* paged_block_size
* initial_batch_size,
max_blocks_per_batch * paged_block_size,
)
]
for b in range(batch_size)
]
slot_mapping_list = [
(((past_seq_len + i) // paged_block_size) * batch_size + b)
* paged_block_size
+ (past_seq_len + i) % paged_block_size
for b in range(batch_size)
for i in range(seq_len)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
slot_mapping = infinicore.from_list(
slot_mapping_list,
dtype=infinicore.int64,
......@@ -170,7 +186,6 @@ class InferEngine(_infinilm.InferEngine):
dtype=infinicore.int64,
)
block_tables = None
slot_mapping = None
past_kv_lengths = infinicore.from_list(
......@@ -207,9 +222,9 @@ class InferEngine(_infinilm.InferEngine):
):
break
input_ids = infinicore.from_list(
[[output_id] for output_id in output_id.to_numpy().tolist()]
)
# start_prepare_time = time.perf_counter()
input_ids = output_id.view([batch_size, 1])
past_seq_len = past_seq_len + seq_len
if _measure_and_log_time:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment