Commit 424cccfe authored by lizhigong's avatar lizhigong
Browse files

fix performance issues caused by enabling TBO

parent f6f8db81
......@@ -270,18 +270,16 @@ def tbo_split_and_execute_model(
inputs_embeds,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
skip_cuda_graphs: bool = True,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
use_tbo = False
if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \
runner.attn_metadata_builders[0]._num_decodes > 0: #is mla decode
use_tbo = False
else:
if len(scheduler_output.num_scheduled_tokens) > 1:
if len(scheduler_output.num_scheduled_tokens) > 1 and num_input_tokens > envs.VLLM_TBO_MIN_TOKENS:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
use_tbo = True
if use_tbo:
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = num_input_tokens - num_input_tokens_left
......@@ -304,11 +302,12 @@ def tbo_split_and_execute_model(
else:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
envs.VLLM_ENABLE_TBO = False
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
skip_cuda_graphs=skip_cuda_graphs):
runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model(
......@@ -321,4 +320,5 @@ def tbo_split_and_execute_model(
runner.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
runner.get_finished_kv_transfers(scheduler_output))
envs.VLLM_ENABLE_TBO = True
return model_output, finished_sending, finished_recving
\ No newline at end of file
......@@ -1377,7 +1377,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors)
inputs_embeds, scheduler_output, intermediate_tensors,
skip_cuda_graphs)
else:
# Run the model.
# Use persistent buffers for CUDA graphs.
......
......@@ -472,12 +472,12 @@ class V1ZeroModelRunner(GPUModelRunner):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
inputs_embeds, scheduler_output, intermediate_tensors)
inputs_embeds, scheduler_output, intermediate_tensors,
skip_cuda_graphs)
else:
# Run the model.
# Use persistent buffers for CUDA graphs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment