Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0cc7c880
Commit
0cc7c880
authored
May 22, 2025
by
lizhigong
Browse files
fix tbo to support deepseek
parent
5aa6d7c2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
306 additions
and
214 deletions
+306
-214
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+8
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+8
-2
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+8
-2
vllm/two_batch_overlap/model_input_split.py
vllm/two_batch_overlap/model_input_split.py
+278
-0
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+4
-208
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
0cc7c880
...
@@ -554,6 +554,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -554,6 +554,9 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
param
:
torch
.
nn
.
Parameter
,
...
@@ -937,8 +940,11 @@ class FusedMoE(torch.nn.Module):
...
@@ -937,8 +940,11 @@ class FusedMoE(torch.nn.Module):
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
# Default set to False. (May have to add shared expert outputs.)
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
if
self
.
enable_tbo
:
final_hidden_states
)
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
return
final_hidden_states
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
0cc7c880
...
@@ -155,6 +155,9 @@ class DeepseekV2MoE(nn.Module):
...
@@ -155,6 +155,9 @@ class DeepseekV2MoE(nn.Module):
reduce_results
=
False
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
@@ -188,8 +191,11 @@ class DeepseekV2MoE(nn.Module):
...
@@ -188,8 +191,11 @@ class DeepseekV2MoE(nn.Module):
# final_hidden_states = final_hidden_states + shared_output \
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
# * (1. / self.routed_scaling_factor)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
if
self
.
enable_tbo
:
final_hidden_states
)
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
0cc7c880
...
@@ -150,6 +150,9 @@ class DeepseekV3MoE(nn.Module):
...
@@ -150,6 +150,9 @@ class DeepseekV3MoE(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
False
,
reduce_results
=
False
,
)
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
,
is_enable_tbo
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
enable_tbo
=
is_enable_tbo
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
@@ -164,8 +167,11 @@ class DeepseekV3MoE(nn.Module):
...
@@ -164,8 +167,11 @@ class DeepseekV3MoE(nn.Module):
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
if
self
.
enable_tbo
:
final_hidden_states
)
final_hidden_states
=
self
.
tbo_all_reduce
(
final_hidden_states
)
else
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
...
vllm/two_batch_overlap/model_input_split.py
0 → 100644
View file @
0cc7c880
import
torch
from
vllm.attention.backends.flashmla
import
FlashMLAMetadata
from
vllm.attention.backends.rocm_flash_attn
import
ROCmFlashAttentionMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.utils
import
async_tensor_h2d
def cumsum(lst):
    """Return the running totals of ``lst`` with a leading zero.

    The result has one more entry than the input: entry ``k`` is the sum of
    the first ``k`` items, so ``cumsum([2, 3]) == [0, 2, 5]`` and an empty
    input yields ``[0]``.

    Args:
        lst: Iterable of addable values (a list of ints at the call sites
            in this file).

    Returns:
        A new list of cumulative sums starting with ``0``.
    """
    cum_lst = [0]
    # Named ``total`` (not ``sum``) so the builtin is not shadowed here.
    total = 0
    for value in lst:
        # Plain ``a = a + b`` rather than ``+=`` so no item is mutated in place.
        total = total + value
        cum_lst.append(total)
    return cum_lst
def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    """Split one batched model input into a left and a right sub-batch.

    The first ``batch_size_left`` sequences go to the left input and the
    remaining sequences go to the right input. Per-token tensors
    (``input_tokens``, ``input_positions``, ``slot_mapping``) are split by the
    summed ``query_lens`` of each half; per-sequence tensors
    (``seq_lens_tensor``, ``block_tables``, ``context_lens_tensor``) are split
    by ``[batch_size_left, batch_size_right]``.

    Args:
        model_input: The full-batch input. Only the attributes read below are
            required; its ``attn_metadata`` must be a
            ``ROCmFlashAttentionMetadata`` or a ``FlashMLAMetadata``.
        self_device: Device on which the new ``query_start_loc`` /
            ``seq_start_loc`` tensors are created.
        batch_size_left: Number of leading sequences placed in the left half.
        batch_size_right: Number of sequences placed in the right half.
            NOTE(review): ``torch.split`` with a list of sizes needs the two
            sizes to add up to the batch dimension -- confirm at the caller.

    Returns:
        ``(model_input_left, model_input_right)``, two
        ``ModelInputForGPUWithSamplingMetadata`` objects.

    Raises:
        ValueError: If either half is empty, because ``max()`` is applied to
            the sliced ``query_lens`` / ``seq_lens`` lists.
        UnboundLocalError: If ``attn_metadata`` is neither of the two
            supported metadata classes (``attn_metadata_left`` is never bound).
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

    # Token counts per half (for per-token tensors) and sequence counts per
    # half (for per-sequence tensors).
    query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]),
                          sum(model_input.query_lens[batch_size_left:])]
    batch_size_split = [batch_size_left, batch_size_right]
    split_input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
    split_input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
    seq_lens_left = model_input.attn_metadata.seq_lens[0:batch_size_left]
    seq_lens_right = model_input.attn_metadata.seq_lens[batch_size_left:]
    query_lens_left = model_input.query_lens[0:batch_size_left]
    query_lens_right = model_input.query_lens[batch_size_left:]
    split_seq_lens_tensor = torch.split(model_input.attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    split_block_tables = torch.split(model_input.attn_metadata.block_tables, batch_size_split, dim=0)

    # Defaults; the prefill or decode group is overwritten below depending on
    # ``model_input.is_prompt``.
    num_prefills_left = 0
    num_prefills_right = 0
    num_prefill_tokens_left = 0
    num_prefill_tokens_right = 0
    num_decode_tokens_left = 0
    num_decode_tokens_right = 0
    max_prefill_seq_len_left = 0
    max_prefill_seq_len_right = 0
    max_decode_seq_len_left = 0
    max_decode_seq_len_right = 0
    max_decode_query_len_left = None
    max_decode_query_len_right = None

    # Encoder / cross-attention fields are never populated by this function;
    # they are passed through as None to the ROCm metadata constructor.
    encoder_seq_lens_left = None
    encoder_seq_lens_right = None
    encoder_seq_lens_tensor_left = None
    encoder_seq_lens_tensor_right = None
    max_encoder_seq_len_left = None
    max_encoder_seq_len_right = None
    num_encoder_tokens_left = None
    num_encoder_tokens_right = None
    cross_slot_mapping_left = None
    cross_slot_mapping_right = None
    cross_block_tables_left = None
    cross_block_tables_right = None

    if model_input.is_prompt:
        # The whole batch is treated as prefill: every sequence in a half
        # counts as one prefill.
        num_prefills_left = batch_size_left
        num_prefills_right = batch_size_right
        num_prefill_tokens_left = sum(model_input.query_lens[0:batch_size_left])
        num_prefill_tokens_right = sum(model_input.query_lens[batch_size_left:])
        max_prefill_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_prefill_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
    else:
        # The whole batch is treated as decode; the decode token count is set
        # to the number of sequences in the half.
        # NOTE(review): this assumes one query token per decode sequence --
        # confirm against the caller.
        num_decode_tokens_left = batch_size_left
        num_decode_tokens_right = batch_size_right
        max_decode_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_decode_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])

    split_slot_mapping = torch.split(model_input.attn_metadata.slot_mapping, query_tokens_split, dim=0)
    max_query_len_left = max(model_input.query_lens[0:batch_size_left])
    max_query_len_right = max(model_input.query_lens[batch_size_left:])

    # Rebuild the start offsets per half so each starts again at 0.
    # query_start_loc comes from the Python query_lens lists; seq_start_loc
    # comes from the already split seq_lens tensors.
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
    query_start_loc_left_list = cumsum(query_lens_left)
    query_start_loc_right_list = cumsum(query_lens_right)
    query_start_loc_left = async_tensor_h2d(query_start_loc_left_list, torch.int32, self_device, True)
    query_start_loc_right = async_tensor_h2d(query_start_loc_right_list, torch.int32, self_device, True)
    seq_start_loc_left = torch.cat((zero_tensor, split_seq_lens_tensor[0].cumsum(dim=0)), dim=0).to(torch.int32)
    seq_start_loc_right = torch.cat((zero_tensor, split_seq_lens_tensor[1].cumsum(dim=0)), dim=0).to(torch.int32)
    split_context_lens_tensor = torch.split(model_input.attn_metadata.context_lens_tensor, batch_size_split, dim=0)

    # Partition the request-id mapping by iteration order: the first
    # ``batch_size_left`` entries go left, the rest go right.
    # NOTE(review): relies on the dict's order matching the batch order and on
    # one entry per sequence -- confirm.
    request_ids_to_seq_ids_left = {}
    request_ids_to_seq_ids_right = {}
    counter = 0
    for key, value in model_input.request_ids_to_seq_ids.items():
        if counter < batch_size_left:
            request_ids_to_seq_ids_left[key] = value
        else:
            request_ids_to_seq_ids_right[key] = value
        counter += 1

    seq_groups_left = None
    seq_groups_right = None
    if model_input.sampling_metadata.seq_groups is not None:
        seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
        seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]

    # Index of the last position of each sequence within its half, derived
    # from the cumulative sequence lengths.
    # NOTE(review): this uses seq_lens rather than query_lens; the two differ
    # whenever a sequence has cached context (e.g. decode) -- confirm these
    # indices are what the consumer expects.
    selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
    selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1

    if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata):
        block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
        block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
        # Cached prefill/decode metadata is reset to None on the new objects,
        # and the multi-modal placeholder maps are replaced by an empty dict.
        attn_metadata_left = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_decode_seq_len=max_decode_seq_len_left,
            block_tables=split_block_tables[0],
            num_prefills=num_prefills_left,
            num_prefill_tokens=num_prefill_tokens_left,
            num_decode_tokens=num_decode_tokens_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            seq_lens=seq_lens_left,
            max_prefill_seq_len=max_prefill_seq_len_left,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            max_query_len=max_query_len_left,
            query_start_loc=query_start_loc_left,
            seq_start_loc=seq_start_loc_left,
            context_lens_tensor=split_context_lens_tensor[0],
            max_decode_query_len=max_decode_query_len_left,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=block_tables_list_left,
            encoder_seq_lens=encoder_seq_lens_left,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor_left,
            max_encoder_seq_len=max_encoder_seq_len_left,
            num_encoder_tokens=num_encoder_tokens_left,
            cross_slot_mapping=cross_slot_mapping_left,
            cross_block_tables=cross_block_tables_left,
        )
        attn_metadata_right = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_decode_seq_len=max_decode_seq_len_right,
            block_tables=split_block_tables[1],
            num_prefills=num_prefills_right,
            num_prefill_tokens=num_prefill_tokens_right,
            num_decode_tokens=num_decode_tokens_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            seq_lens=seq_lens_right,
            max_prefill_seq_len=max_prefill_seq_len_right,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            max_query_len=max_query_len_right,
            query_start_loc=query_start_loc_right,
            seq_start_loc=seq_start_loc_right,
            context_lens_tensor=split_context_lens_tensor[1],
            max_decode_query_len=max_decode_query_len_right,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=block_tables_list_right,
            encoder_seq_lens=encoder_seq_lens_right,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor_right,
            max_encoder_seq_len=max_encoder_seq_len_right,
            num_encoder_tokens=num_encoder_tokens_right,
            cross_slot_mapping=cross_slot_mapping_right,
            cross_block_tables=cross_block_tables_right,
        )

    # A separate ``if`` (not ``elif``): both type checks are always evaluated.
    if isinstance(model_input.attn_metadata, FlashMLAMetadata):
        # The context_chunk_* fields, decode_tile_scheduler_metadata and
        # decode_num_splits are copied unchanged from the full-batch metadata
        # into both halves.
        # NOTE(review): confirm these full-batch values remain valid for a
        # sub-batch.
        attn_metadata_left = FlashMLAMetadata(
            num_prefills=num_prefills_left,
            num_prefill_tokens=num_prefill_tokens_left,
            num_decode_tokens=num_decode_tokens_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[0],
            seq_lens=seq_lens_left,
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_prefill_seq_len=max_prefill_seq_len_left,
            max_decode_seq_len=max_decode_seq_len_left,
            context_lens_tensor=split_context_lens_tensor[0],
            block_tables=split_block_tables[0],
            max_query_len=max_query_len_left,
            max_decode_query_len=max_decode_query_len_left,
            query_start_loc=query_start_loc_left,
            seq_start_loc=seq_start_loc_left,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
            decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
            decode_num_splits=model_input.attn_metadata.decode_num_splits)
        attn_metadata_right = FlashMLAMetadata(
            num_prefills=num_prefills_right,
            num_prefill_tokens=num_prefill_tokens_right,
            num_decode_tokens=num_decode_tokens_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[1],
            seq_lens=seq_lens_right,
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_prefill_seq_len=max_prefill_seq_len_right,
            max_decode_seq_len=max_decode_seq_len_right,
            context_lens_tensor=split_context_lens_tensor[1],
            block_tables=split_block_tables[1],
            max_query_len=max_query_len_right,
            max_decode_query_len=max_decode_query_len_right,
            query_start_loc=query_start_loc_right,
            seq_start_loc=seq_start_loc_right,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
            decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
            decode_num_splits=model_input.attn_metadata.decode_num_splits)

    # Fields that are not split (LoRA / prompt-adapter state, multi-modal
    # kwargs, finished request ids, scheduler outputs, previous hidden states,
    # categorized_sample_indices) are passed to both halves as the same
    # full-batch objects. token_types is always set to None.
    model_input_left = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[0],
        input_positions=split_input_positions[0],
        token_types=None,
        seq_lens=seq_lens_left,
        query_lens=query_lens_left,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_left,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_left,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=model_input.previous_hidden_states,
        sampling_metadata=SamplingMetadata(
            seq_groups=seq_groups_left,
            selected_token_indices=selected_token_indices_left,
            categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
            num_prompts=num_prefills_left,
            skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
            reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
        ),
        is_prompt=model_input.is_prompt,
    )
    model_input_right = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[1],
        input_positions=split_input_positions[1],
        token_types=None,
        seq_lens=seq_lens_right,
        query_lens=query_lens_right,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_right,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_right,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=model_input.previous_hidden_states,
        sampling_metadata=SamplingMetadata(
            seq_groups=seq_groups_right,
            selected_token_indices=selected_token_indices_right,
            categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
            num_prompts=num_prefills_right,
            skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
            reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
        ),
        is_prompt=model_input.is_prompt,
    )
    return model_input_left, model_input_right
vllm/two_batch_overlap/two_batch_overlap.py
View file @
0cc7c880
...
@@ -3,12 +3,13 @@ import os
...
@@ -3,12 +3,13 @@ import os
import
queue
import
queue
import
threading
import
threading
import
torch
import
torch
from
vllm.attention.backends.flashmla
import
FlashMLAMetadata
from
vllm.attention.backends.rocm_flash_attn
import
ROCmFlashAttentionMetadata
from
vllm.attention.backends.rocm_flash_attn
import
ROCmFlashAttentionMetadata
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.two_batch_overlap.forward_context
import
init_tbo_forward_context
from
vllm.two_batch_overlap.model_input_split
import
split_model_input
from
vllm.utils
import
async_tensor_h2d
from
vllm.utils
import
async_tensor_h2d
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
...
@@ -203,212 +204,6 @@ def tbo_all_reduce(obj):
...
@@ -203,212 +204,6 @@ def tbo_all_reduce(obj):
return
output
return
output
return
tensor_model_parallel_all_reduce
(
obj
)
return
tensor_model_parallel_all_reduce
(
obj
)
def
cumsum
(
lst
):
cum_lst
=
[
0
]
sum
=
0
for
i
in
range
(
0
,
len
(
lst
)):
sum
=
sum
+
lst
[
i
]
cum_lst
.
append
(
sum
)
return
cum_lst
def
split_model_input
(
model_input
,
self_device
,
batch_size_left
,
batch_size_right
):
query_tokens_split
=
[
sum
(
model_input
.
query_lens
[
0
:
batch_size_left
]),
sum
(
model_input
.
query_lens
[
batch_size_left
:])]
batch_size_split
=
[
batch_size_left
,
batch_size_right
]
split_input_tokens
=
torch
.
split
(
model_input
.
input_tokens
,
query_tokens_split
,
dim
=
0
)
split_input_positions
=
torch
.
split
(
model_input
.
input_positions
,
query_tokens_split
,
dim
=
0
)
seq_lens_left
=
model_input
.
attn_metadata
.
seq_lens
[
0
:
batch_size_left
]
seq_lens_right
=
model_input
.
attn_metadata
.
seq_lens
[
batch_size_left
:]
query_lens_left
=
model_input
.
query_lens
[
0
:
batch_size_left
]
query_lens_right
=
model_input
.
query_lens
[
batch_size_left
:]
split_seq_lens_tensor
=
torch
.
split
(
model_input
.
attn_metadata
.
seq_lens_tensor
,
batch_size_split
,
dim
=
0
)
split_block_tables
=
torch
.
split
(
model_input
.
attn_metadata
.
block_tables
,
batch_size_split
,
dim
=
0
)
num_prefills_left
=
0
num_prefills_right
=
0
num_prefill_tokens_left
=
0
num_prefill_tokens_right
=
0
num_decode_tokens_left
=
0
num_decode_tokens_right
=
0
max_prefill_seq_len_left
=
0
max_prefill_seq_len_right
=
0
max_decode_seq_len_left
=
0
max_decode_seq_len_right
=
0
max_decode_query_len_left
=
None
max_decode_query_len_right
=
None
encoder_seq_lens_left
=
None
encoder_seq_lens_right
=
None
encoder_seq_lens_tensor_left
=
None
encoder_seq_lens_tensor_right
=
None
max_encoder_seq_len_left
=
None
max_encoder_seq_len_right
=
None
num_encoder_tokens_left
=
None
num_encoder_tokens_right
=
None
cross_slot_mapping_left
=
None
cross_slot_mapping_right
=
None
cross_block_tables_left
=
None
cross_block_tables_right
=
None
if
model_input
.
is_prompt
:
num_prefills_left
=
batch_size_left
num_prefills_right
=
batch_size_right
num_prefill_tokens_left
=
sum
(
model_input
.
query_lens
[
0
:
batch_size_left
])
num_prefill_tokens_right
=
sum
(
model_input
.
query_lens
[
batch_size_left
:])
max_prefill_seq_len_left
=
max
(
model_input
.
attn_metadata
.
seq_lens
[
0
:
batch_size_left
])
max_prefill_seq_len_right
=
max
(
model_input
.
attn_metadata
.
seq_lens
[
batch_size_left
:])
else
:
num_decode_tokens_left
=
batch_size_left
num_decode_tokens_right
=
batch_size_right
max_decode_seq_len_left
=
max
(
model_input
.
attn_metadata
.
seq_lens
[
0
:
batch_size_left
])
max_decode_seq_len_right
=
max
(
model_input
.
attn_metadata
.
seq_lens
[
batch_size_left
:])
split_slot_mapping
=
torch
.
split
(
model_input
.
attn_metadata
.
slot_mapping
,
query_tokens_split
,
dim
=
0
)
max_query_len_left
=
max
(
model_input
.
query_lens
[
0
:
batch_size_left
])
max_query_len_right
=
max
(
model_input
.
query_lens
[
batch_size_left
:])
zero_tensor
=
torch
.
tensor
([
0
],
device
=
self_device
,
dtype
=
torch
.
int32
)
query_start_loc_left_list
=
cumsum
(
query_lens_left
)
query_start_loc_right_list
=
cumsum
(
query_lens_right
)
query_start_loc_left
=
async_tensor_h2d
(
query_start_loc_left_list
,
torch
.
int32
,
self_device
,
True
)
query_start_loc_right
=
async_tensor_h2d
(
query_start_loc_right_list
,
torch
.
int32
,
self_device
,
True
)
seq_start_loc_left
=
torch
.
cat
((
zero_tensor
,
split_seq_lens_tensor
[
0
].
cumsum
(
dim
=
0
)),
dim
=
0
).
to
(
torch
.
int32
)
seq_start_loc_right
=
torch
.
cat
((
zero_tensor
,
split_seq_lens_tensor
[
1
].
cumsum
(
dim
=
0
)),
dim
=
0
).
to
(
torch
.
int32
)
split_context_lens_tensor
=
torch
.
split
(
model_input
.
attn_metadata
.
context_lens_tensor
,
batch_size_split
,
dim
=
0
)
block_tables_list_left
=
model_input
.
attn_metadata
.
block_tables_list
[
0
:
batch_size_left
]
block_tables_list_right
=
model_input
.
attn_metadata
.
block_tables_list
[
batch_size_left
:]
request_ids_to_seq_ids_left
=
{}
request_ids_to_seq_ids_right
=
{}
counter
=
0
for
key
,
value
in
model_input
.
request_ids_to_seq_ids
.
items
():
if
counter
<
batch_size_left
:
request_ids_to_seq_ids_left
[
key
]
=
value
else
:
request_ids_to_seq_ids_right
[
key
]
=
value
counter
+=
1
seq_groups_left
=
None
seq_groups_right
=
None
if
model_input
.
sampling_metadata
.
seq_groups
is
not
None
:
seq_groups_left
=
model_input
.
sampling_metadata
.
seq_groups
[
0
:
batch_size_left
]
seq_groups_right
=
model_input
.
sampling_metadata
.
seq_groups
[
batch_size_left
:]
selected_token_indices_left
=
split_seq_lens_tensor
[
0
].
cumsum
(
dim
=
0
)
-
1
selected_token_indices_right
=
split_seq_lens_tensor
[
1
].
cumsum
(
dim
=
0
)
-
1
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
attn_metadata_left
=
ROCmFlashAttentionMetadata
(
seq_lens_tensor
=
split_seq_lens_tensor
[
0
],
max_decode_seq_len
=
max_decode_seq_len_left
,
block_tables
=
split_block_tables
[
0
],
num_prefills
=
num_prefills_left
,
num_prefill_tokens
=
num_prefill_tokens_left
,
num_decode_tokens
=
num_decode_tokens_left
,
slot_mapping
=
split_slot_mapping
[
0
],
multi_modal_placeholder_index_maps
=
{},
enable_kv_scales_calculation
=
model_input
.
attn_metadata
.
enable_kv_scales_calculation
,
seq_lens
=
seq_lens_left
,
max_prefill_seq_len
=
max_prefill_seq_len_left
,
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
,
max_query_len
=
max_query_len_left
,
query_start_loc
=
query_start_loc_left
,
seq_start_loc
=
seq_start_loc_left
,
context_lens_tensor
=
split_context_lens_tensor
[
0
],
max_decode_query_len
=
max_decode_query_len_left
,
_cached_prefill_metadata
=
None
,
_cached_decode_metadata
=
None
,
tree_attention_masks_tensor
=
None
,
block_tables_list
=
block_tables_list_left
,
encoder_seq_lens
=
encoder_seq_lens_left
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor_left
,
max_encoder_seq_len
=
max_encoder_seq_len_left
,
num_encoder_tokens
=
num_encoder_tokens_left
,
cross_slot_mapping
=
cross_slot_mapping_left
,
cross_block_tables
=
cross_block_tables_left
,
)
model_input_left
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
split_input_tokens
[
0
],
input_positions
=
split_input_positions
[
0
],
token_types
=
None
,
seq_lens
=
seq_lens_left
,
query_lens
=
query_lens_left
,
lora_mapping
=
model_input
.
lora_mapping
,
lora_requests
=
model_input
.
lora_requests
,
attn_metadata
=
attn_metadata_left
,
prompt_adapter_mapping
=
model_input
.
prompt_adapter_mapping
,
prompt_adapter_requests
=
model_input
.
prompt_adapter_requests
,
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids_left
,
finished_requests_ids
=
model_input
.
finished_requests_ids
,
virtual_engine
=
model_input
.
virtual_engine
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
model_input
.
previous_hidden_states
,
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups_left
,
selected_token_indices
=
selected_token_indices_left
,
categorized_sample_indices
=
model_input
.
sampling_metadata
.
categorized_sample_indices
,
num_prompts
=
num_prefills_left
,
skip_sampler_cpu_output
=
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
,
reuse_sampling_tensors
=
model_input
.
sampling_metadata
.
reuse_sampling_tensors
,
),
is_prompt
=
model_input
.
is_prompt
,
)
attn_metadata_right
=
ROCmFlashAttentionMetadata
(
seq_lens_tensor
=
split_seq_lens_tensor
[
1
],
max_decode_seq_len
=
max_decode_seq_len_right
,
block_tables
=
split_block_tables
[
1
],
num_prefills
=
num_prefills_right
,
num_prefill_tokens
=
num_prefill_tokens_right
,
num_decode_tokens
=
num_decode_tokens_right
,
slot_mapping
=
split_slot_mapping
[
1
],
multi_modal_placeholder_index_maps
=
{},
enable_kv_scales_calculation
=
model_input
.
attn_metadata
.
enable_kv_scales_calculation
,
seq_lens
=
seq_lens_right
,
max_prefill_seq_len
=
max_prefill_seq_len_right
,
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
,
max_query_len
=
max_query_len_right
,
query_start_loc
=
query_start_loc_right
,
seq_start_loc
=
seq_start_loc_right
,
context_lens_tensor
=
split_context_lens_tensor
[
1
],
max_decode_query_len
=
max_decode_query_len_right
,
_cached_prefill_metadata
=
None
,
_cached_decode_metadata
=
None
,
tree_attention_masks_tensor
=
None
,
block_tables_list
=
block_tables_list_right
,
encoder_seq_lens
=
encoder_seq_lens_right
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor_right
,
max_encoder_seq_len
=
max_encoder_seq_len_right
,
num_encoder_tokens
=
num_encoder_tokens_right
,
cross_slot_mapping
=
cross_slot_mapping_right
,
cross_block_tables
=
cross_block_tables_right
,
)
model_input_right
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
split_input_tokens
[
1
],
input_positions
=
split_input_positions
[
1
],
token_types
=
None
,
seq_lens
=
seq_lens_right
,
query_lens
=
query_lens_right
,
lora_mapping
=
model_input
.
lora_mapping
,
lora_requests
=
model_input
.
lora_requests
,
attn_metadata
=
attn_metadata_right
,
prompt_adapter_mapping
=
model_input
.
prompt_adapter_mapping
,
prompt_adapter_requests
=
model_input
.
prompt_adapter_requests
,
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids_right
,
finished_requests_ids
=
model_input
.
finished_requests_ids
,
virtual_engine
=
model_input
.
virtual_engine
,
async_callback
=
model_input
.
async_callback
,
scheduler_outputs
=
model_input
.
scheduler_outputs
,
previous_hidden_states
=
model_input
.
previous_hidden_states
,
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups_right
,
selected_token_indices
=
selected_token_indices_right
,
categorized_sample_indices
=
model_input
.
sampling_metadata
.
categorized_sample_indices
,
num_prompts
=
num_prefills_right
,
skip_sampler_cpu_output
=
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
,
reuse_sampling_tensors
=
model_input
.
sampling_metadata
.
reuse_sampling_tensors
,
),
is_prompt
=
model_input
.
is_prompt
,
)
return
model_input_left
,
model_input_right
def merge_model_output(states_left, states_right):
    """Join the outputs of the two sub-batches back into one tensor.

    The left states come first, followed by the right states, concatenated
    along dimension 0.
    """
    return torch.concat((states_left, states_right), dim=0)
...
@@ -426,11 +221,12 @@ def tbo_model_executable(
...
@@ -426,11 +221,12 @@ def tbo_model_executable(
):
):
init_two_batch_overlap
()
init_two_batch_overlap
()
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_rocm_fa
=
isinstance
(
model_input
.
attn_metadata
,
ROCmFlashAttentionMetadata
)
is_mla_fa
=
isinstance
(
model_input
.
attn_metadata
,
FlashMLAMetadata
)
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
is_cuda_graph_decode
=
model_input
.
attn_metadata
.
use_cuda_graph
and
not
model_input
.
is_prompt
batch_size
=
len
(
model_input
.
attn_metadata
.
seq_lens
)
batch_size
=
len
(
model_input
.
attn_metadata
.
seq_lens
)
if
batch_size
==
1
or
\
if
batch_size
==
1
or
\
(
not
model_input
.
is_prompt
and
not
enable_tbo_decode
)
or
\
(
not
model_input
.
is_prompt
and
not
enable_tbo_decode
)
or
\
not
is_rocm_fa
or
\
not
(
is_rocm_fa
or
is_mla_fa
)
or
\
is_cuda_graph_decode
:
is_cuda_graph_decode
:
with
set_forward_context
(
model_input
.
attn_metadata
,
with
set_forward_context
(
model_input
.
attn_metadata
,
vllm_config
,
virtual_engine
):
vllm_config
,
virtual_engine
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment