Commit 231a170a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

fix tbo to support deepseek

See merge request dcutoolkit/deeplearing/vllm!118
parents 3b5d646e 0cc7c880
......@@ -554,6 +554,9 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
......@@ -937,8 +940,11 @@ class FusedMoE(torch.nn.Module):
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.enable_tbo:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states
......
......@@ -155,6 +155,9 @@ class DeepseekV2MoE(nn.Module):
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -188,8 +191,11 @@ class DeepseekV2MoE(nn.Module):
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.enable_tbo:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
......
......@@ -150,6 +150,9 @@ class DeepseekV3MoE(nn.Module):
quant_config=quant_config,
reduce_results=False,
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -164,8 +167,11 @@ class DeepseekV3MoE(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.enable_tbo:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
......
import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import async_tensor_h2d
def cumsum(lst):
cum_lst = [0]
sum = 0
for i in range(0, len(lst)):
sum = sum + lst[i]
cum_lst.append(sum)
return cum_lst
def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
batch_size_split = [batch_size_left, batch_size_right]
split_input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
split_input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
seq_lens_left = model_input.attn_metadata.seq_lens[0:batch_size_left]
seq_lens_right = model_input.attn_metadata.seq_lens[batch_size_left:]
query_lens_left = model_input.query_lens[0:batch_size_left]
query_lens_right = model_input.query_lens[batch_size_left:]
split_seq_lens_tensor = torch.split(model_input.attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
split_block_tables = torch.split(model_input.attn_metadata.block_tables, batch_size_split, dim=0)
num_prefills_left = 0
num_prefills_right = 0
num_prefill_tokens_left = 0
num_prefill_tokens_right = 0
num_decode_tokens_left = 0
num_decode_tokens_right = 0
max_prefill_seq_len_left = 0
max_prefill_seq_len_right = 0
max_decode_seq_len_left = 0
max_decode_seq_len_right = 0
max_decode_query_len_left = None
max_decode_query_len_right = None
encoder_seq_lens_left = None
encoder_seq_lens_right = None
encoder_seq_lens_tensor_left = None
encoder_seq_lens_tensor_right = None
max_encoder_seq_len_left = None
max_encoder_seq_len_right = None
num_encoder_tokens_left = None
num_encoder_tokens_right = None
cross_slot_mapping_left = None
cross_slot_mapping_right = None
cross_block_tables_left = None
cross_block_tables_right = None
if model_input.is_prompt:
num_prefills_left = batch_size_left
num_prefills_right = batch_size_right
num_prefill_tokens_left = sum(model_input.query_lens[0:batch_size_left])
num_prefill_tokens_right = sum(model_input.query_lens[batch_size_left:])
max_prefill_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
max_prefill_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
else:
num_decode_tokens_left = batch_size_left
num_decode_tokens_right = batch_size_right
max_decode_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
max_decode_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
split_slot_mapping = torch.split(model_input.attn_metadata.slot_mapping, query_tokens_split, dim=0)
max_query_len_left = max(model_input.query_lens[0:batch_size_left])
max_query_len_right = max(model_input.query_lens[batch_size_left:])
zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
query_start_loc_left_list = cumsum(query_lens_left)
query_start_loc_right_list = cumsum(query_lens_right)
query_start_loc_left = async_tensor_h2d(query_start_loc_left_list, torch.int32,
self_device,
True)
query_start_loc_right = async_tensor_h2d(query_start_loc_right_list, torch.int32,
self_device,
True)
seq_start_loc_left = torch.cat((zero_tensor, split_seq_lens_tensor[0].cumsum(dim=0)), dim=0).to(torch.int32)
seq_start_loc_right = torch.cat((zero_tensor, split_seq_lens_tensor[1].cumsum(dim=0)), dim=0).to(torch.int32)
split_context_lens_tensor = torch.split(model_input.attn_metadata.context_lens_tensor, batch_size_split, dim=0)
request_ids_to_seq_ids_left = {}
request_ids_to_seq_ids_right = {}
counter = 0
for key, value in model_input.request_ids_to_seq_ids.items():
if counter < batch_size_left:
request_ids_to_seq_ids_left[key] = value
else:
request_ids_to_seq_ids_right[key] = value
counter += 1
seq_groups_left = None
seq_groups_right = None
if model_input.sampling_metadata.seq_groups is not None:
seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata):
block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
attn_metadata_left = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[0],
max_decode_seq_len = max_decode_seq_len_left,
block_tables = split_block_tables[0],
num_prefills = num_prefills_left,
num_prefill_tokens = num_prefill_tokens_left,
num_decode_tokens = num_decode_tokens_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = {},
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
seq_lens = seq_lens_left,
max_prefill_seq_len = max_prefill_seq_len_left,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
max_query_len = max_query_len_left,
query_start_loc = query_start_loc_left,
seq_start_loc = seq_start_loc_left,
context_lens_tensor = split_context_lens_tensor[0],
max_decode_query_len = max_decode_query_len_left,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = block_tables_list_left,
encoder_seq_lens = encoder_seq_lens_left,
encoder_seq_lens_tensor = encoder_seq_lens_tensor_left,
max_encoder_seq_len = max_encoder_seq_len_left,
num_encoder_tokens = num_encoder_tokens_left,
cross_slot_mapping = cross_slot_mapping_left,
cross_block_tables = cross_block_tables_left,
)
attn_metadata_right = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[1],
max_decode_seq_len = max_decode_seq_len_right,
block_tables = split_block_tables[1],
num_prefills = num_prefills_right,
num_prefill_tokens = num_prefill_tokens_right,
num_decode_tokens = num_decode_tokens_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = {},
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
seq_lens = seq_lens_right,
max_prefill_seq_len = max_prefill_seq_len_right,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
max_query_len = max_query_len_right,
query_start_loc = query_start_loc_right,
seq_start_loc = seq_start_loc_right,
context_lens_tensor = split_context_lens_tensor[1],
max_decode_query_len = max_decode_query_len_right,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = block_tables_list_right,
encoder_seq_lens = encoder_seq_lens_right,
encoder_seq_lens_tensor = encoder_seq_lens_tensor_right,
max_encoder_seq_len = max_encoder_seq_len_right,
num_encoder_tokens = num_encoder_tokens_right,
cross_slot_mapping = cross_slot_mapping_right,
cross_block_tables = cross_block_tables_right,
)
if isinstance(model_input.attn_metadata, FlashMLAMetadata):
attn_metadata_left = FlashMLAMetadata(
num_prefills = num_prefills_left,
num_prefill_tokens = num_prefill_tokens_left,
num_decode_tokens = num_decode_tokens_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[0],
seq_lens = seq_lens_left,
seq_lens_tensor = split_seq_lens_tensor[0],
max_prefill_seq_len = max_prefill_seq_len_left,
max_decode_seq_len = max_decode_seq_len_left,
context_lens_tensor = split_context_lens_tensor[0],
block_tables = split_block_tables[0],
max_query_len = max_query_len_left,
max_decode_query_len = max_decode_query_len_left,
query_start_loc = query_start_loc_left,
seq_start_loc = seq_start_loc_left,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
decode_num_splits=model_input.attn_metadata.decode_num_splits
)
attn_metadata_right = FlashMLAMetadata(
num_prefills = num_prefills_right,
num_prefill_tokens = num_prefill_tokens_right,
num_decode_tokens = num_decode_tokens_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[1],
seq_lens = seq_lens_right,
seq_lens_tensor = split_seq_lens_tensor[1],
max_prefill_seq_len = max_prefill_seq_len_right,
max_decode_seq_len = max_decode_seq_len_right,
context_lens_tensor = split_context_lens_tensor[1],
block_tables = split_block_tables[1],
max_query_len = max_query_len_right,
max_decode_query_len = max_decode_query_len_right,
query_start_loc = query_start_loc_right,
seq_start_loc = seq_start_loc_right,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
decode_num_splits=model_input.attn_metadata.decode_num_splits
)
model_input_left = ModelInputForGPUWithSamplingMetadata(
input_tokens=split_input_tokens[0],
input_positions=split_input_positions[0],
token_types=None,
seq_lens=seq_lens_left,
query_lens=query_lens_left,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
attn_metadata=attn_metadata_left,
prompt_adapter_mapping=model_input.prompt_adapter_mapping,
prompt_adapter_requests=model_input.prompt_adapter_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids_left,
finished_requests_ids=model_input.finished_requests_ids,
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_left,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt,
)
model_input_right = ModelInputForGPUWithSamplingMetadata(
input_tokens=split_input_tokens[1],
input_positions=split_input_positions[1],
token_types=None,
seq_lens=seq_lens_right,
query_lens=query_lens_right,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
attn_metadata=attn_metadata_right,
prompt_adapter_mapping=model_input.prompt_adapter_mapping,
prompt_adapter_requests=model_input.prompt_adapter_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids_right,
finished_requests_ids=model_input.finished_requests_ids,
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_right,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt,
)
return model_input_left, model_input_right
......@@ -3,12 +3,13 @@ import os
import queue
import threading
import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import split_model_input
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger
from vllm.profiler.prof import profile
......@@ -203,212 +204,6 @@ def tbo_all_reduce(obj):
return output
return tensor_model_parallel_all_reduce(obj)
def cumsum(lst):
cum_lst = [0]
sum = 0
for i in range(0, len(lst)):
sum = sum + lst[i]
cum_lst.append(sum)
return cum_lst
def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
batch_size_split = [batch_size_left, batch_size_right]
split_input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
split_input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
seq_lens_left = model_input.attn_metadata.seq_lens[0:batch_size_left]
seq_lens_right = model_input.attn_metadata.seq_lens[batch_size_left:]
query_lens_left = model_input.query_lens[0:batch_size_left]
query_lens_right = model_input.query_lens[batch_size_left:]
split_seq_lens_tensor = torch.split(model_input.attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
split_block_tables = torch.split(model_input.attn_metadata.block_tables, batch_size_split, dim=0)
num_prefills_left = 0
num_prefills_right = 0
num_prefill_tokens_left = 0
num_prefill_tokens_right = 0
num_decode_tokens_left = 0
num_decode_tokens_right = 0
max_prefill_seq_len_left = 0
max_prefill_seq_len_right = 0
max_decode_seq_len_left = 0
max_decode_seq_len_right = 0
max_decode_query_len_left = None
max_decode_query_len_right = None
encoder_seq_lens_left = None
encoder_seq_lens_right = None
encoder_seq_lens_tensor_left = None
encoder_seq_lens_tensor_right = None
max_encoder_seq_len_left = None
max_encoder_seq_len_right = None
num_encoder_tokens_left = None
num_encoder_tokens_right = None
cross_slot_mapping_left = None
cross_slot_mapping_right = None
cross_block_tables_left = None
cross_block_tables_right = None
if model_input.is_prompt:
num_prefills_left = batch_size_left
num_prefills_right = batch_size_right
num_prefill_tokens_left = sum(model_input.query_lens[0:batch_size_left])
num_prefill_tokens_right = sum(model_input.query_lens[batch_size_left:])
max_prefill_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
max_prefill_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
else:
num_decode_tokens_left = batch_size_left
num_decode_tokens_right = batch_size_right
max_decode_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
max_decode_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
split_slot_mapping = torch.split(model_input.attn_metadata.slot_mapping, query_tokens_split, dim=0)
max_query_len_left = max(model_input.query_lens[0:batch_size_left])
max_query_len_right = max(model_input.query_lens[batch_size_left:])
zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
query_start_loc_left_list = cumsum(query_lens_left)
query_start_loc_right_list = cumsum(query_lens_right)
query_start_loc_left = async_tensor_h2d(query_start_loc_left_list, torch.int32,
self_device,
True)
query_start_loc_right = async_tensor_h2d(query_start_loc_right_list, torch.int32,
self_device,
True)
seq_start_loc_left = torch.cat((zero_tensor, split_seq_lens_tensor[0].cumsum(dim=0)), dim=0).to(torch.int32)
seq_start_loc_right = torch.cat((zero_tensor, split_seq_lens_tensor[1].cumsum(dim=0)), dim=0).to(torch.int32)
split_context_lens_tensor = torch.split(model_input.attn_metadata.context_lens_tensor, batch_size_split, dim=0)
block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
request_ids_to_seq_ids_left = {}
request_ids_to_seq_ids_right = {}
counter = 0
for key, value in model_input.request_ids_to_seq_ids.items():
if counter < batch_size_left:
request_ids_to_seq_ids_left[key] = value
else:
request_ids_to_seq_ids_right[key] = value
counter += 1
seq_groups_left = None
seq_groups_right = None
if model_input.sampling_metadata.seq_groups is not None:
seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
attn_metadata_left = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[0],
max_decode_seq_len = max_decode_seq_len_left,
block_tables = split_block_tables[0],
num_prefills = num_prefills_left,
num_prefill_tokens = num_prefill_tokens_left,
num_decode_tokens = num_decode_tokens_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = {},
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
seq_lens = seq_lens_left,
max_prefill_seq_len = max_prefill_seq_len_left,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
max_query_len = max_query_len_left,
query_start_loc = query_start_loc_left,
seq_start_loc = seq_start_loc_left,
context_lens_tensor = split_context_lens_tensor[0],
max_decode_query_len = max_decode_query_len_left,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = block_tables_list_left,
encoder_seq_lens = encoder_seq_lens_left,
encoder_seq_lens_tensor = encoder_seq_lens_tensor_left,
max_encoder_seq_len = max_encoder_seq_len_left,
num_encoder_tokens = num_encoder_tokens_left,
cross_slot_mapping = cross_slot_mapping_left,
cross_block_tables = cross_block_tables_left,
)
model_input_left = ModelInputForGPUWithSamplingMetadata(
input_tokens=split_input_tokens[0],
input_positions=split_input_positions[0],
token_types=None,
seq_lens=seq_lens_left,
query_lens=query_lens_left,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
attn_metadata=attn_metadata_left,
prompt_adapter_mapping=model_input.prompt_adapter_mapping,
prompt_adapter_requests=model_input.prompt_adapter_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids_left,
finished_requests_ids=model_input.finished_requests_ids,
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_left,
selected_token_indices=selected_token_indices_left,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_left,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt,
)
attn_metadata_right = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[1],
max_decode_seq_len = max_decode_seq_len_right,
block_tables = split_block_tables[1],
num_prefills = num_prefills_right,
num_prefill_tokens = num_prefill_tokens_right,
num_decode_tokens = num_decode_tokens_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = {},
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
seq_lens = seq_lens_right,
max_prefill_seq_len = max_prefill_seq_len_right,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
max_query_len = max_query_len_right,
query_start_loc = query_start_loc_right,
seq_start_loc = seq_start_loc_right,
context_lens_tensor = split_context_lens_tensor[1],
max_decode_query_len = max_decode_query_len_right,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = block_tables_list_right,
encoder_seq_lens = encoder_seq_lens_right,
encoder_seq_lens_tensor = encoder_seq_lens_tensor_right,
max_encoder_seq_len = max_encoder_seq_len_right,
num_encoder_tokens = num_encoder_tokens_right,
cross_slot_mapping = cross_slot_mapping_right,
cross_block_tables = cross_block_tables_right,
)
model_input_right = ModelInputForGPUWithSamplingMetadata(
input_tokens=split_input_tokens[1],
input_positions=split_input_positions[1],
token_types=None,
seq_lens=seq_lens_right,
query_lens=query_lens_right,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
attn_metadata=attn_metadata_right,
prompt_adapter_mapping=model_input.prompt_adapter_mapping,
prompt_adapter_requests=model_input.prompt_adapter_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids_right,
finished_requests_ids=model_input.finished_requests_ids,
virtual_engine=model_input.virtual_engine,
async_callback=model_input.async_callback,
scheduler_outputs=model_input.scheduler_outputs,
previous_hidden_states=model_input.previous_hidden_states,
sampling_metadata=SamplingMetadata(
seq_groups=seq_groups_right,
selected_token_indices=selected_token_indices_right,
categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
num_prompts=num_prefills_right,
skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
),
is_prompt=model_input.is_prompt,
)
return model_input_left, model_input_right
def merge_model_output(states_left, states_right):
output = torch.concat([states_left, states_right], dim=0)
return output
......@@ -426,11 +221,12 @@ def tbo_model_executable(
):
init_two_batch_overlap()
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
is_mla_fa = isinstance(model_input.attn_metadata, FlashMLAMetadata)
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
batch_size = len(model_input.attn_metadata.seq_lens)
if batch_size == 1 or \
(not model_input.is_prompt and not enable_tbo_decode) or \
not is_rocm_fa or \
not (is_rocm_fa or is_mla_fa) or \
is_cuda_graph_decode:
with set_forward_context(model_input.attn_metadata,
vllm_config, virtual_engine):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment