Commit 78fb5dda authored by maxiao1's avatar maxiao1
Browse files

fix bug in tbo

parent 8e838a89
...@@ -138,13 +138,6 @@ def prepare_tbo_atten_metadata( ...@@ -138,13 +138,6 @@ def prepare_tbo_atten_metadata(
runner.query_start_loc_np[0] = 0 runner.query_start_loc_np[0] = 0
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# --- seq_lens (absolute context length per-req row) ---
# Default (no split across req boundary)
# Maps rows [req_offset ... req_offset+num_reqs-1]
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs]
+ num_scheduled_tokens
)
# Offsets for copying into the *global* GPU buffers # Offsets for copying into the *global* GPU buffers
# Left-half writes at the natural position; right-half depends on split. # Left-half writes at the natural position; right-half depends on split.
...@@ -152,6 +145,10 @@ def prepare_tbo_atten_metadata( ...@@ -152,6 +145,10 @@ def prepare_tbo_atten_metadata(
# LEFT # LEFT
seq_len_offset = 0 seq_len_offset = 0
query_start_offset = 0 query_start_offset = 0
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[0 : num_reqs]
+ num_scheduled_tokens
)
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device) seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
else: else:
# RIGHT # RIGHT
...@@ -181,6 +178,10 @@ def prepare_tbo_atten_metadata( ...@@ -181,6 +178,10 @@ def prepare_tbo_atten_metadata(
# RIGHT without split-in-req: natural positions # RIGHT without split-in-req: natural positions
seq_len_offset = req_offset seq_len_offset = req_offset
query_start_offset = req_offset query_start_offset = req_offset
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs]
+ num_scheduled_tokens
)
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device) seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
# Copy query_start_loc into global GPU buffer window # Copy query_start_loc into global GPU buffer window
...@@ -306,8 +307,7 @@ def tbo_split_and_execute_model( ...@@ -306,8 +307,7 @@ def tbo_split_and_execute_model(
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("###############enter tbo") print("###############enter tbo")
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is). # # Two-batch overlap path
# Two-batch overlap path
split_scheduler_output(runner, scheduler_output) split_scheduler_output(runner, scheduler_output)
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = input_split.scheduler_output_right.total_num_scheduled_tokens num_input_tokens_right = input_split.scheduler_output_right.total_num_scheduled_tokens
...@@ -320,21 +320,21 @@ def tbo_split_and_execute_model( ...@@ -320,21 +320,21 @@ def tbo_split_and_execute_model(
) )
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector === # === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# 真实 token # real token nums
real_L = int(input_split.scheduler_output_left.total_num_scheduled_tokens) num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
real_R = int(input_split.scheduler_output_right.total_num_scheduled_tokens) num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# 按左右半批切成两份 # split intermediate tensors
def _split_it(it, l, r): def _split_intermediate_tensors(it, l, r):
if it is None: return None, None if it is None: return None, None
lm, rm = {}, {} left_tensor_map, right_tensor_map = {}, {}
for k, v in it.tensors.items(): for name, tensor in it.tensors.items():
vl, vr = torch.split(v[:l + r], [l, r], dim=0) vl, vr = torch.split(tensor[:l + r], [l, r], dim=0)
lm[k], rm[k] = vl, vr left_tensor_map[name], right_tensor_map[name] = vl, vr
return IntermediateTensors(lm), IntermediateTensors(rm) return IntermediateTensors(left_tensor_map), IntermediateTensors(right_tensor_map)
intermediate_tensors_left, intermediate_tensors_right = _split_it( intermediate_tensors_left, intermediate_tensors_right = _split_intermediate_tensors(
intermediate_tensors, real_L, real_R intermediate_tensors, num_tokens_left, num_tokens_right
) )
runner.maybe_setup_kv_connector(scheduler_output) runner.maybe_setup_kv_connector(scheduler_output)
...@@ -354,3 +354,5 @@ def tbo_split_and_execute_model( ...@@ -354,3 +354,5 @@ def tbo_split_and_execute_model(
finished_sending, finished_recving = runner.get_finished_kv_transfers(scheduler_output) finished_sending, finished_recving = runner.get_finished_kv_transfers(scheduler_output)
return model_output, finished_sending, finished_recving return model_output, finished_sending, finished_recving
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment