Commit f6f8db81 authored by lizhigong's avatar lizhigong
Browse files

fix bugs in zero overhead and tbo

parent 14201006
......@@ -25,6 +25,7 @@ class TBOModelInputSplit():
self.req_num_right = 0
self.scheduler_output_left = None
self.scheduler_output_right = None
self.query_start_loc_right = None
input_split = TBOModelInputSplit()
......@@ -136,78 +137,39 @@ def prepare_tbo_atten_metadata(
assert num_reqs > 0
seq_len_offset = req_offset
if req_offset == 0: #left
query_start_offset = 0
else:
query_start_offset = req_offset + 1
# Get the number of scheduled tokens for each request.
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(runner.arange_np[:num_reqs],
num_scheduled_tokens) + req_offset
if req_offset > 0: #right
if input_split.query_start_loc_right == None:
# TODO: create when system init
input_split.query_start_loc_right = torch.zeros(runner.max_num_reqs + 1,
dtype=torch.int32,
device=runner.device)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = runner._get_cumsum_and_arange(
num_scheduled_tokens)
cu_num_tokens, arange = runner._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = runner.positions_np[:total_num_scheduled_tokens]
np.add(runner.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Prepare the attention metadata.
runner.query_start_loc_np[0] = 0
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# Calculate the slot mapping for each KV cache group.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
runner.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = runner.input_batch.block_table[
kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
runner.query_start_loc_np[0] = 0
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
runner.seq_lens_np[:num_reqs] = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs] +
num_scheduled_tokens)
runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].copy_(
runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
if req_offset > 0: #right
runner.query_start_loc[query_start_offset + num_reqs + 1:].fill_(
input_split.query_start_loc_right[0: num_reqs + 1].copy_(
runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
input_split.query_start_loc_right[num_reqs + 1:].fill_(
runner.query_start_loc_cpu[num_reqs].item())
runner.seq_lens[seq_len_offset :seq_len_offset + num_reqs].copy_(runner.seq_lens_cpu[:num_reqs],
non_blocking=True)
query_start_loc = input_split.query_start_loc_right[: num_reqs + 1]
# Fill unused with -1. Needed for reshape_and_cache
if req_offset > 0: #right
runner.seq_lens[seq_len_offset + num_reqs:].fill_(0)
query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1]
else:
query_start_loc = runner.query_start_loc[:num_reqs + 1]
seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs]
common_attn_metadata = CommonAttentionMetadata(
......@@ -240,6 +202,9 @@ def prepare_tbo_atten_metadata(
origin_slot_mapping = metadata_builder.block_table.slot_mapping
metadata_builder.block_table.slot_mapping = \
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
origin_slot_map_cpu = metadata_builder.block_table.slot_mapping_cpu
metadata_builder.block_table.slot_mapping_cpu = \
origin_slot_map_cpu[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
_num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills
......@@ -257,6 +222,7 @@ def prepare_tbo_atten_metadata(
if req_offset > 0:
metadata_builder.block_table.block_table = origin_block_table
metadata_builder.block_table.slot_mapping = origin_slot_mapping
metadata_builder.block_table.slot_mapping_cpu = origin_slot_map_cpu
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
metadata_builder._num_decodes = _num_decodes_record
......
......@@ -80,6 +80,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request._output_token_ids[fix_offset] = generated_token_ids
request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids]
else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
......@@ -107,7 +108,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
pooler_output = pooler_outputs[req_idx]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
if stopped:
......@@ -118,7 +119,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
new_logprobs = logprobs.slice(req_idx, req_idx + 1)
if new_token_ids and scheduler.structured_output_manager.should_advance(
request):
......
......@@ -9,6 +9,6 @@ class ZeroV1ModelRunnerOutput(ModelRunnerOutput):
# [num_reqs]
fix_req_ids: list[str] = None
fix_sampled_token_ids:list[list[int]] = None
fix_draft_req_ids:list[list[int]] = None
fix_draft_req_ids:list[str] = None
fix_draft_tokens_ids:list[list[int]] = None
is_output_valid:bool = True
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment