Commit aaef2077 authored by zhuwenwen's avatar zhuwenwen
Browse files

v1 engine eager tbo support mla attention

parent ffbb211b
......@@ -58,7 +58,8 @@ class TwoBatchOverlap():
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
logger.info('tbo:two batch overlap start')
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self):
self.left_thread.join()
......
......@@ -8,6 +8,7 @@ from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
......@@ -223,28 +224,47 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
if runner.cascade_attn_enabled:
common_prefix_len = runner._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
runner.attn_metadata_builders[kv_cache_group_id],
metadata_builder,
)
if req_offset > 0:
origin_block_table = runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = \
origin_block_table = metadata_builder.block_table.block_table
metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = metadata_builder.block_table.slot_mapping
metadata_builder.block_table.slot_mapping = \
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
_num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills
_num_decode_tokens_record = metadata_builder._num_decode_tokens
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens
metadata_builder._num_decodes = 0
metadata_builder._num_prefills = num_reqs
metadata_builder._num_decode_tokens = 0
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
attn_metadata_i = (
runner.attn_metadata_builders[kv_cache_group_id].build(
metadata_builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
if req_offset > 0:
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = origin_slot_mapping
metadata_builder.block_table.block_table = origin_block_table
metadata_builder.block_table.slot_mapping = origin_slot_mapping
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
metadata_builder._num_decodes = _num_decodes_record
metadata_builder._num_prefills = _num_prefills_record
metadata_builder._num_decode_tokens = _num_decode_tokens_record
metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
......@@ -287,11 +307,15 @@ def tbo_split_and_execute_model(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
use_tbo = False
if len(scheduler_output.num_scheduled_tokens) > 1:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \
runner.attn_metadata_builders[0]._num_decodes > 0: #is mla decode
use_tbo = False
else:
if len(scheduler_output.num_scheduled_tokens) > 1:
split_scheduler_output(runner, scheduler_output)
if input_split.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
input_split.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if use_tbo:
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
......@@ -318,7 +342,8 @@ def tbo_split_and_execute_model(
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model(
......
......@@ -50,7 +50,8 @@ class TwoBatchOverlap():
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
logger.info('tbo:two batch overlap start')
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self):
self.left_thread.join()
......@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp):
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment