Commit 91ccb1cc authored by lizhigong's avatar lizhigong
Browse files

v1 engine eager tbo support mla attention

parent ba29eebb
...@@ -58,6 +58,7 @@ class TwoBatchOverlap(): ...@@ -58,6 +58,7 @@ class TwoBatchOverlap():
self.left_thread.start() self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start() self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start') logger.info('tbo:two batch overlap start')
def finish_thread(self): def finish_thread(self):
......
...@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context ...@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1 from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata( ...@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial. # Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0 common_prefix_len = 0
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
if runner.cascade_attn_enabled: if runner.cascade_attn_enabled:
common_prefix_len = runner._compute_cascade_attn_prefix_len( common_prefix_len = runner._compute_cascade_attn_prefix_len(
num_scheduled_tokens, num_scheduled_tokens,
scheduler_output. scheduler_output.
num_common_prefix_blocks[kv_cache_group_id], num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec, kv_cache_group_spec.kv_cache_spec,
runner.attn_metadata_builders[kv_cache_group_id], metadata_builder,
) )
if req_offset > 0: if req_offset > 0:
origin_block_table = runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table origin_block_table = metadata_builder.block_table.block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table[req_offset:, :] metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping origin_slot_mapping = metadata_builder.block_table.slot_mapping
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = \ metadata_builder.block_table.slot_mapping = \
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:] origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
_num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills
_num_decode_tokens_record = metadata_builder._num_decode_tokens
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens
metadata_builder._num_decodes = 0
metadata_builder._num_prefills = num_reqs
metadata_builder._num_decode_tokens = 0
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
attn_metadata_i = ( attn_metadata_i = (
runner.attn_metadata_builders[kv_cache_group_id].build( metadata_builder.build(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
if req_offset > 0: if req_offset > 0:
runner.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table metadata_builder.block_table.block_table = origin_block_table
runner.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = origin_slot_mapping metadata_builder.block_table.slot_mapping = origin_slot_mapping
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
metadata_builder._num_decodes = _num_decodes_record
metadata_builder._num_prefills = _num_prefills_record
metadata_builder._num_decode_tokens = _num_decode_tokens_record
metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
...@@ -319,7 +337,8 @@ def tbo_split_and_execute_model( ...@@ -319,7 +337,8 @@ def tbo_split_and_execute_model(
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
runner.vllm_config, runner.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp): num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output) runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model( model_output = runner.model(
......
...@@ -50,6 +50,7 @@ class TwoBatchOverlap(): ...@@ -50,6 +50,7 @@ class TwoBatchOverlap():
self.left_thread.start() self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start() self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start') logger.info('tbo:two batch overlap start')
def finish_thread(self): def finish_thread(self):
...@@ -90,7 +91,8 @@ class TwoBatchOverlap(): ...@@ -90,7 +91,8 @@ class TwoBatchOverlap():
with set_forward_context(attn_metadata, with set_forward_context(attn_metadata,
self.model_runner.vllm_config, self.model_runner.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp): num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
model_output = self.model_runner.model( model_output = self.model_runner.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment