Commit 78fb5dda authored by maxiao1's avatar maxiao1
Browse files

fix bug in tbo

parent 8e838a89
......@@ -138,13 +138,6 @@ def prepare_tbo_atten_metadata(
runner.query_start_loc_np[0] = 0
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# --- seq_lens (absolute context length per-req row) ---
# Default (no split across req boundary)
# Maps rows [req_offset ... req_offset+num_reqs-1]
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs]
+ num_scheduled_tokens
)
# Offsets for copying into the *global* GPU buffers
# Left-half writes at the natural position; right-half depends on split.
......@@ -152,6 +145,10 @@ def prepare_tbo_atten_metadata(
# LEFT
seq_len_offset = 0
query_start_offset = 0
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[0 : num_reqs]
+ num_scheduled_tokens
)
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
else:
# RIGHT
......@@ -181,6 +178,10 @@ def prepare_tbo_atten_metadata(
# RIGHT without split-in-req: natural positions
seq_len_offset = req_offset
query_start_offset = req_offset
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs]
+ num_scheduled_tokens
)
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
# Copy query_start_loc into global GPU buffer window
......@@ -306,8 +307,7 @@ def tbo_split_and_execute_model(
) -> Union[ModelRunnerOutput, IntermediateTensors]:
if torch.distributed.get_rank() == 0:
print("###############enter tbo")
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
# Two-batch overlap path
# # Two-batch overlap path
split_scheduler_output(runner, scheduler_output)
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = input_split.scheduler_output_right.total_num_scheduled_tokens
......@@ -320,21 +320,21 @@ def tbo_split_and_execute_model(
)
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# 真实 token
real_L = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
real_R = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# real token nums
num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# 按左右半批切成两份
def _split_it(it, l, r):
# split intermediate tensors
def _split_intermediate_tensors(it, l, r):
if it is None: return None, None
lm, rm = {}, {}
for k, v in it.tensors.items():
vl, vr = torch.split(v[:l + r], [l, r], dim=0)
lm[k], rm[k] = vl, vr
return IntermediateTensors(lm), IntermediateTensors(rm)
intermediate_tensors_left, intermediate_tensors_right = _split_it(
intermediate_tensors, real_L, real_R
left_tensor_map, right_tensor_map = {}, {}
for name, tensor in it.tensors.items():
vl, vr = torch.split(tensor[:l + r], [l, r], dim=0)
left_tensor_map[name], right_tensor_map[name] = vl, vr
return IntermediateTensors(left_tensor_map), IntermediateTensors(right_tensor_map)
intermediate_tensors_left, intermediate_tensors_right = _split_intermediate_tensors(
intermediate_tensors, num_tokens_left, num_tokens_right
)
runner.maybe_setup_kv_connector(scheduler_output)
......@@ -354,3 +354,5 @@ def tbo_split_and_execute_model(
finished_sending, finished_recving = runner.get_finished_kv_transfers(scheduler_output)
return model_output, finished_sending, finished_recving
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment