Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cda54326
Commit
cda54326
authored
Jan 12, 2026
by
王敏
Browse files
[feat]添加dp attention功能
parent
e89003dd
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
412 additions
and
93 deletions
+412
-93
vllm/config.py
vllm/config.py
+13
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-0
vllm/forward_context.py
vllm/forward_context.py
+12
-1
vllm/model_executor/layers/dp_attention.py
vllm/model_executor/layers/dp_attention.py
+133
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+12
-7
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+54
-43
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+37
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+30
-6
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+14
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+68
-18
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+15
-15
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+7
-0
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+9
-0
No files found.
vllm/config.py
View file @
cda54326
...
@@ -1883,6 +1883,9 @@ class ParallelConfig:
...
@@ -1883,6 +1883,9 @@ class ParallelConfig:
""" Use data parallelism instead of tensor parallelism for vision encoder.
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
Only support LLama4 for now"""
enable_dp_attention
:
bool
=
False
"""Enable dp attention"""
@
property
@
property
def
world_size_across_dp
(
self
)
->
int
:
def
world_size_across_dp
(
self
)
->
int
:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
"""world_size_across_dp is TPxPPxDP, it is the size of the world
...
@@ -2108,6 +2111,9 @@ class ParallelConfig:
...
@@ -2108,6 +2111,9 @@ class ParallelConfig:
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
"run with Ray."
)
if
self
.
enable_dp_attention
and
self
.
enable_expert_parallel
:
raise
ValueError
(
"Dp attention and expert parallel can not enable together."
)
return
self
return
self
...
@@ -4805,6 +4811,7 @@ class VllmConfig:
...
@@ -4805,6 +4811,7 @@ class VllmConfig:
dp_size
=
self
.
parallel_config
.
data_parallel_size
dp_size
=
self
.
parallel_config
.
data_parallel_size
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
ep_sp
=
self
.
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
tp_size
>
1
ep_sp
=
self
.
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
tp_size
>
1
enable_dp_attention
=
self
.
parallel_config
.
enable_dp_attention
# add for spec decode
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
...
@@ -4813,11 +4820,15 @@ class VllmConfig:
...
@@ -4813,11 +4820,15 @@ class VllmConfig:
batch_size_capture_list
=
sorted
(
set
(
batch_size_capture_list
+
mtp_batch_size_capture_list
))
batch_size_capture_list
=
sorted
(
set
(
batch_size_capture_list
+
mtp_batch_size_capture_list
))
batch_size_capture_list
=
[
i
for
i
in
batch_size_capture_list
if
i
==
1
or
i
%
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
==
0
]
batch_size_capture_list
=
[
i
for
i
in
batch_size_capture_list
if
i
==
1
or
i
%
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
==
0
]
if
ep_sp
:
if
ep_sp
or
enable_dp_attention
:
batch_size_capture_list
=
sorted
(
set
([
round_up
(
i
,
tp_size
)
for
i
in
batch_size_capture_list
]))
batch_size_capture_list
=
sorted
(
set
([
round_up
(
i
,
tp_size
)
for
i
in
batch_size_capture_list
]))
if
1
not
in
batch_size_capture_list
:
batch_size_capture_list
.
insert
(
0
,
1
)
else
:
else
:
if
ep_sp
:
if
ep_sp
or
enable_dp_attention
:
batch_size_capture_list
=
sorted
(
set
([
round_up
(
i
,
tp_size
)
for
i
in
batch_size_capture_list
]))
batch_size_capture_list
=
sorted
(
set
([
round_up
(
i
,
tp_size
)
for
i
in
batch_size_capture_list
]))
if
1
not
in
batch_size_capture_list
:
batch_size_capture_list
.
insert
(
0
,
1
)
self
.
compilation_config
.
init_with_cudagraph_sizes
(
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
batch_size_capture_list
)
...
...
vllm/engine/arg_utils.py
View file @
cda54326
...
@@ -476,6 +476,9 @@ class EngineArgs:
...
@@ -476,6 +476,9 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel
:
bool
=
\
enable_multimodal_encoder_data_parallel
:
bool
=
\
ParallelConfig
.
enable_multimodal_encoder_data_parallel
ParallelConfig
.
enable_multimodal_encoder_data_parallel
enable_dp_attention
:
bool
=
\
ParallelConfig
.
enable_dp_attention
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# support `EngineArgs(compilation_config={...})`
...
@@ -718,6 +721,10 @@ class EngineArgs:
...
@@ -718,6 +721,10 @@ class EngineArgs:
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--enable-multimodal-encoder-data-parallel"
,
"--enable-multimodal-encoder-data-parallel"
,
**
parallel_kwargs
[
"enable_multimodal_encoder_data_parallel"
])
**
parallel_kwargs
[
"enable_multimodal_encoder_data_parallel"
])
parallel_group
.
add_argument
(
"--enable-dp-attention"
,
**
parallel_kwargs
[
"enable_dp_attention"
])
# KV cache arguments
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_kwargs
=
get_kwargs
(
CacheConfig
)
...
@@ -1204,6 +1211,7 @@ class EngineArgs:
...
@@ -1204,6 +1211,7 @@ class EngineArgs:
worker_extension_cls
=
self
.
worker_extension_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
enable_multimodal_encoder_data_parallel
=
self
.
enable_multimodal_encoder_data_parallel
=
self
.
enable_multimodal_encoder_data_parallel
,
enable_multimodal_encoder_data_parallel
,
enable_dp_attention
=
self
.
enable_dp_attention
,
)
)
speculative_config
=
self
.
create_speculative_config
(
speculative_config
=
self
.
create_speculative_config
(
...
...
vllm/forward_context.py
View file @
cda54326
...
@@ -210,4 +210,15 @@ def set_profilling(profiling):
...
@@ -210,4 +210,15 @@ def set_profilling(profiling):
def
get_profilling
()
->
bool
:
def
get_profilling
()
->
bool
:
global
_profiling
global
_profiling
return
_profiling
return
_profiling
\ No newline at end of file
# Process-wide flag recording whether the engine is currently warming up.
_warming_up = False


@contextmanager
def _restore_warming_up(previous):
    """Context manager that puts the warm-up flag back to ``previous`` on exit."""
    global _warming_up
    try:
        yield
    finally:
        _warming_up = previous


def set_warming_up(warming_up):
    """Set the global warm-up flag.

    The flag is updated immediately, so a plain call such as
    ``set_warming_up(True)`` keeps working. The returned object may
    optionally be used in a ``with`` statement, in which case the previous
    value of the flag is restored when the block exits.

    The earlier version decorated a function that never yields with
    ``@contextmanager``; plain calls happened to work because the decorator
    invokes the wrapped function eagerly, but entering the result in a
    ``with`` statement raised ``TypeError``.

    Args:
        warming_up: New value of the warm-up flag.

    Returns:
        A context manager that restores the previous flag value on exit.
    """
    global _warming_up
    previous = _warming_up
    _warming_up = warming_up
    return _restore_warming_up(previous)


def get_warming_up() -> bool:
    """Return the current value of the global warm-up flag."""
    return _warming_up
\ No newline at end of file
vllm/model_executor/layers/dp_attention.py
0 → 100644
View file @
cda54326
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
import
logging
import
torch
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
GroupCoordinator
,
init_model_parallel_group
,
get_world_group
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
get_tensor_model_parallel_rank
,
tensor_model_parallel_reduce_scatter
,
get_tp_group
)
# Module-level state filled in by initialize_dp_attention().

# Mirrors parallel_config.enable_dp_attention once initialization has run.
_ENABLE_DP_ATTENTION_FLAG: bool = False

# Process group used for MoE tensor parallelism; None until initialized.
_MOE_TP: Optional[GroupCoordinator] = None

# Attention-side sizes and ranks.
# NOTE(review): these default to 0 rather than None, so the
# ``is not None`` assertions in the getters cannot detect a missing
# initialization.
_ATTN_TP_SIZE = 0
_ATTN_TP_RANK = 0
_ATTN_DP_SIZE = 0
_ATTN_DP_RANK = 0

# MoE-side size and rank ("MOT" is the spelling used throughout this module).
_MOT_TP_SIZE = 0
_MOT_TP_RANK = 0
def initialize_dp_attention(vllm_config, backend: Optional[str] = None):
    """Cache DP-attention sizes/ranks and create the MoE tensor-parallel group.

    Reads the parallel configuration from ``vllm_config``, stores the
    attention TP/DP sizes and ranks plus the MoE TP size and rank in
    module-level globals, and builds the ``"moe_tp"`` process group.

    Args:
        vllm_config: The active ``VllmConfig`` instance.
        backend: Distributed backend for the new group. When omitted, the
            backend of the world group's device group is used.

    Raises:
        AssertionError: If ``vllm_config`` is not a ``VllmConfig`` or the MoE
            tensor-parallel group has already been created.
    """
    from vllm.config import VllmConfig

    global _ENABLE_DP_ATTENTION_FLAG, _ATTN_DP_SIZE, _ATTN_TP_SIZE
    global _ATTN_TP_RANK, _ATTN_DP_RANK, _MOT_TP_SIZE, _MOT_TP_RANK
    global _MOE_TP

    assert isinstance(vllm_config, VllmConfig)
    # Check before touching any module state so that a repeated call fails
    # without partially overwriting the cached values.
    assert _MOE_TP is None, (
        "moe tensor model parallel group is already initialized")

    parallel_config = vllm_config.parallel_config
    _ENABLE_DP_ATTENTION_FLAG = parallel_config.enable_dp_attention

    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    pipeline_model_parallel_size = parallel_config.pipeline_parallel_size
    # The MoE TP group spans every rank that is not split by pipeline
    # parallelism.
    moe_tp_size = world_size // pipeline_model_parallel_size

    _ATTN_DP_SIZE = parallel_config.data_parallel_size
    _ATTN_TP_SIZE = parallel_config.tensor_parallel_size
    _ATTN_TP_RANK = get_tensor_model_parallel_rank()
    _ATTN_DP_RANK = parallel_config.data_parallel_rank
    _MOT_TP_SIZE = moe_tp_size
    _MOT_TP_RANK = rank % _MOT_TP_SIZE

    world_group = get_world_group()
    backend = backend or torch.distributed.get_backend(
        world_group.device_group)

    # Build the moe tensor model-parallel groups: one contiguous block of
    # ``moe_tp_size`` ranks per pipeline stage.
    # NOTE(review): contiguous blocks assume the ranks of one pipeline stage
    # are adjacent in the world group - confirm for pipeline_parallel_size > 1.
    group_ranks = [
        list(range(stage * moe_tp_size, (stage + 1) * moe_tp_size))
        for stage in range(pipeline_model_parallel_size)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _MOE_TP = init_model_parallel_group(group_ranks,
                                        world_group.local_rank,
                                        backend,
                                        use_message_queue_broadcaster=True,
                                        group_name="moe_tp")
def get_attention_tp_size() -> int:
    """Return the attention tensor-parallel size cached at initialization."""
    tp_size = _ATTN_TP_SIZE
    # NOTE(review): the module default is 0, not None, so this check cannot
    # fire before initialize_dp_attention() has run.
    assert tp_size is not None, "dp attention not initialized!"
    return tp_size
def get_attention_tp_rank() -> int:
    """Return the attention tensor-parallel rank cached at initialization."""
    tp_rank = _ATTN_TP_RANK
    # NOTE(review): the module default is 0, not None, so this check cannot
    # fire before initialize_dp_attention() has run.
    assert tp_rank is not None, "dp attention not initialized!"
    return tp_rank
def get_moe_tp_group() -> GroupCoordinator:
    """Return the MoE tensor-parallel process group.

    Raises:
        AssertionError: If the group has not been created yet.
    """
    group = _MOE_TP
    assert group is not None, (
        "tensor model parallel group is not initialized")
    return group
def get_attention_dp_size() -> int:
    """Return the attention data-parallel size cached at initialization."""
    dp_size = _ATTN_DP_SIZE
    # NOTE(review): the module default is 0, not None, so this check cannot
    # fire before initialize_dp_attention() has run.
    assert dp_size is not None, "dp attention not initialized!"
    return dp_size
def get_moe_tp_rank() -> int:
    """Return this process's rank inside the MoE tensor-parallel group."""
    moe_rank = _MOT_TP_RANK
    # NOTE(review): the module default is 0, not None, so this check cannot
    # fire before initialize_dp_attention() has run.
    assert moe_rank is not None, "dp attention not initialized!"
    return moe_rank
def get_moe_tp_size() -> int:
    """Return the size of the MoE tensor-parallel group cached at initialization."""
    moe_size = _MOT_TP_SIZE
    # NOTE(review): the module default is 0, not None, so this check cannot
    # fire before initialize_dp_attention() has run.
    assert moe_size is not None, "dp attention not initialized!"
    return moe_size
def get_attention_tp_group() -> GroupCoordinator:
    """Return the group used for attention tensor parallelism.

    This is whatever ``get_tp_group()`` returns.
    """
    attention_group = get_tp_group()
    return attention_group
def moe_tensor_model_parallel_all_gather(input_: torch.Tensor,
                                         dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` along ``dim`` across the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    gathered = moe_group.all_gather(input_, dim)
    return gathered
def moe_tensor_model_parallel_reduce_scatter(input_: torch.Tensor,
                                             dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across the MoE tensor-parallel group."""
    moe_group = get_moe_tp_group()
    scattered = moe_group.reduce_scatter(input_, dim)
    return scattered
def dp_gather(hidden_states: torch.Tensor, ) -> torch.Tensor:
    """Gather hidden states along dim 0 across the MoE tensor-parallel group.

    When the attention tensor-parallel size is not 1, the input is first
    reduce-scattered along dim 0 over the regular tensor-parallel group; in
    every case the result is then all-gathered along dim 0 over the MoE
    tensor-parallel group.
    """
    if get_attention_tp_size() != 1:
        hidden_states = tensor_model_parallel_reduce_scatter(hidden_states,
                                                             dim=0)
    return moe_tensor_model_parallel_all_gather(hidden_states, dim=0)
def dp_reduce_scatter_tensor(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter hidden states along dim 0 across the MoE TP group.

    If the MoE tensor-parallel group's world size differs from the attention
    data-parallel size, the scattered result is additionally all-gathered
    along dim 0 over the regular tensor-parallel group.
    """
    # Evaluate the condition first, matching the original call order.
    needs_tp_gather = (
        get_moe_tp_group().world_size != get_attention_dp_size())
    scattered = moe_tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
    if needs_tp_gather:
        scattered = tensor_model_parallel_all_gather(scattered, dim=0)
    return scattered
vllm/model_executor/layers/fused_moe/layer.py
View file @
cda54326
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
,
is_power_of_two
)
fused_topk
,
grouped_topk
,
is_power_of_two
)
from
vllm.model_executor.layers.dp_attention
import
get_moe_tp_rank
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
...
@@ -175,7 +176,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -175,7 +176,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager
.
world_size
,
all2all_manager
.
world_size
,
)
)
ll_handle
=
all2all_manager
.
get_handle
(
ll_all_to_all_args
)
ll_handle
=
all2all_manager
.
get_handle
(
ll_all_to_all_args
)
# HT prepare/finalize built on the same LL handle per request
# HT prepare/finalize built on the same LL handle per request
ht_prepare_finalize
=
DeepEPHTPrepareAndFinalize
(
ht_prepare_finalize
=
DeepEPHTPrepareAndFinalize
(
ll_handle
,
ll_handle
,
...
@@ -203,10 +204,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -203,10 +204,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize
=
DeepEPAutoPrepareAndFinalize
(
prepare_finalize
=
DeepEPAutoPrepareAndFinalize
(
ht_prepare_finalize
,
ll_prepare_finalize
)
ht_prepare_finalize
,
ll_prepare_finalize
)
experts_ht
=
self
.
select_gemm_impl
(
ht_prepare_finalize
,
moe
)
experts_ht
=
self
.
select_gemm_impl
(
ht_prepare_finalize
,
moe
)
experts_ll
=
self
.
select_gemm_impl
(
ll_prepare_finalize
,
moe
)
experts_ll
=
self
.
select_gemm_impl
(
ll_prepare_finalize
,
moe
)
self
.
topk_indices_dtype
=
ll_prepare_finalize
.
topk_indices_dtype
()
self
.
topk_indices_dtype
=
ll_prepare_finalize
.
topk_indices_dtype
()
self
.
fused_experts
=
DeepGemmDisabledFusedMoEModularKernel
(
self
.
fused_experts
=
DeepGemmDisabledFusedMoEModularKernel
(
prepare_finalize
,
prepare_finalize
,
...
@@ -214,9 +215,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -214,9 +215,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
experts_ht
=
experts_ht
,
experts_ht
=
experts_ht
,
experts_ll
=
experts_ll
,
experts_ll
=
experts_ll
,
shared_experts
=
layer
.
shared_experts
if
hasattr
(
layer
,
"shared_experts"
)
else
None
,
shared_experts
=
layer
.
shared_experts
if
hasattr
(
layer
,
"shared_experts"
)
else
None
,
)
)
return
return
elif
moe
.
use_deepep_ht_kernels
:
elif
moe
.
use_deepep_ht_kernels
:
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
assert
moe
.
dp_size
==
all2all_manager
.
dp_world_size
...
@@ -900,6 +901,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -900,6 +901,8 @@ class FusedMoE(torch.nn.Module):
self
.
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
self
.
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
self
.
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
self
.
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
self
.
enable_dp_attention
=
vllm_config
.
parallel_config
.
enable_dp_attention
# Determine expert maps
# Determine expert maps
if
self
.
use_ep
:
if
self
.
use_ep
:
if
self
.
enable_eplb
:
if
self
.
enable_eplb
:
...
@@ -1620,7 +1623,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -1620,7 +1623,8 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default.
The pplx combine kernel reduces across GPU ranks by default.
"""
"""
if
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
if
(
self
.
use_pplx_kernels
or
self
.
use_deepep_ht_kernels
or
self
.
use_deepep_ll_kernels
or
self
.
use_deepep_auto_kernels
):
or
self
.
use_deepep_ll_kernels
or
self
.
use_deepep_auto_kernels
or
self
.
enable_dp_attention
):
return
final_hidden_states
return
final_hidden_states
else
:
else
:
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
@@ -1759,6 +1763,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1759,6 +1763,7 @@ class FusedMoE(torch.nn.Module):
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_ll_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_ll_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_auto_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_auto_kernels
and
not
self
.
enable_dp_attention
)
)
if
do_naive_dispatch_combine
:
if
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch
(
hidden_states
,
router_logits
=
get_ep_group
().
dispatch
(
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
cda54326
...
@@ -22,6 +22,7 @@ from vllm.utils import round_up
...
@@ -22,6 +22,7 @@ from vllm.utils import round_up
try
:
try
:
from
lmslim.layers.gemm.int8_utils
import
(
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_group_quant_int8
,
per_token_quant_int8
)
from
lightop
import
op
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to use int utils.
\n
"
)
print
(
"INFO: Please install lmslim if you want to use int utils.
\n
"
)
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -622,52 +623,62 @@ def ep_scatter(
...
@@ -622,52 +623,62 @@ def ep_scatter(
num_experts
=
num_recv_tokens_per_expert
.
shape
[
0
]
num_experts
=
num_recv_tokens_per_expert
.
shape
[
0
]
hidden_size
=
recv_x
.
shape
[
1
]
hidden_size
=
recv_x
.
shape
[
1
]
scale_hidden_size
=
recv_x_scale
.
shape
[
-
1
]
scale_hidden_size
=
recv_x_scale
.
shape
[
-
1
]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid
=
num_experts
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
if
hasattr
(
op
,
"ep_scatter"
):
op
.
ep_scatter
(
recv_x
,
recv_x_scale
,
recv_topk
,
expert_map
,
num_recv_tokens_per_expert
,
output_tensor
,
output_tensor_scale
,
m_indices
,
output_index
,
num_experts
,
BLOCK_E
)
else
:
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid
=
num_experts
_fwd_kernel_ep_scatter_1
[(
grid
,)](
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
=
num_experts
,
num_warps
=
num_warps
,
BLOCK_E
=
BLOCK_E
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
)
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
_fwd_kernel_ep_scatter_1
[(
grid
,)](
_fwd_kernel_ep_scatter_2
[(
grid
,)](
num_recv_tokens_per_expert
,
recv_topk
.
shape
[
0
],
expert_start_loc
,
expert_start_loc
,
m_indices
,
recv_x
,
num_experts
=
num_experts
,
recv_x
.
stride
(
0
),
num_warps
=
num_warps
,
recv_x
.
stride
(
1
),
BLOCK_E
=
BLOCK_E
,
recv_x_scale
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
recv_x_scale
.
stride
(
0
),
)
recv_x_scale
.
stride
(
1
),
recv_topk
,
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
recv_topk
.
stride
(
0
),
_fwd_kernel_ep_scatter_2
[(
grid
,)](
recv_topk
.
stride
(
1
),
recv_topk
.
shape
[
0
],
output_tensor
,
expert_start_loc
,
output_tensor
.
stride
(
0
),
recv_x
,
output_tensor
.
stride
(
1
),
recv_x
.
stride
(
0
),
output_tensor_scale
,
recv_x
.
stride
(
1
),
output_tensor_scale
.
stride
(
0
),
recv_x_scale
,
output_tensor_scale
.
stride
(
1
),
recv_x_scale
.
stride
(
0
),
output_index
,
recv_x_scale
.
stride
(
1
),
output_index
.
stride
(
0
),
recv_topk
,
output_index
.
stride
(
1
),
recv_topk
.
stride
(
0
),
topk_num
=
recv_topk
.
shape
[
1
],
recv_topk
.
stride
(
1
),
expert_map
=
expert_map
,
output_tensor
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
output_tensor
.
stride
(
0
),
num_warps
=
num_warps
,
output_tensor
.
stride
(
1
),
HIDDEN_SIZE
=
hidden_size
,
output_tensor_scale
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
output_tensor_scale
.
stride
(
0
),
SCALE_HIDDEN_SIZE
=
scale_hidden_size
,
#hidden_size // BLOCK_D,
output_tensor_scale
.
stride
(
1
),
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
scale_hidden_size
)
#triton.next_power_of_2(hidden_size // BLOCK_D),
output_index
,
)
output_index
.
stride
(
0
),
output_index
.
stride
(
1
),
topk_num
=
recv_topk
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
HIDDEN_SIZE
=
hidden_size
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
SCALE_HIDDEN_SIZE
=
scale_hidden_size
,
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
scale_hidden_size
),
)
return
return
...
...
vllm/model_executor/layers/linear.py
View file @
cda54326
...
@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
...
@@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
RowvLLMParameter
)
RowvLLMParameter
)
from
vllm.model_executor.layers.dp_attention
import
get_moe_tp_rank
,
get_moe_tp_size
# yapf: enable
# yapf: enable
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
...
@@ -625,12 +626,18 @@ class ColumnParallelLinear(LinearBase):
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
expect_tp_size
:
Optional
[
int
]
=
None
,
enable_dp_attn_moe
:
bool
=
False
,
):
):
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
if
expect_tp_size
is
not
None
:
self
.
expect_tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
self
.
tp_size
=
self
.
expect_tp_size
self
.
tp_size
=
self
.
expect_tp_size
self
.
enable_dp_attn_moe
=
enable_dp_attn_moe
if
enable_dp_attn_moe
:
self
.
tp_size
=
get_moe_tp_size
()
self
.
input_size_per_partition
=
input_size
self
.
input_size_per_partition
=
input_size
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
...
@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -878,6 +885,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
expect_tp_size
:
Optional
[
int
]
=
None
,
enable_dp_attn_moe
:
bool
=
False
,
):
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -888,6 +896,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
self
.
expect_tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
self
.
enable_dp_attn_moe
=
enable_dp_attn_moe
if
enable_dp_attn_moe
:
tp_size
=
get_moe_tp_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
=
input_size
,
super
().
__init__
(
input_size
=
input_size
,
output_size
=
sum
(
output_sizes
),
output_size
=
sum
(
output_sizes
),
...
@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -898,7 +910,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
return_bias
=
return_bias
,
return_bias
=
return_bias
,
expect_tp_size
=
expect_tp_size
)
expect_tp_size
=
expect_tp_size
,
enable_dp_attn_moe
=
enable_dp_attn_moe
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -999,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -999,6 +1012,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
self
.
expect_tp_size
is
not
None
and
self
.
expect_tp_size
==
1
:
if
self
.
expect_tp_size
is
not
None
and
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_rank
=
0
tp_size
=
1
tp_size
=
1
if
self
.
enable_dp_attn_moe
:
tp_rank
=
get_moe_tp_rank
()
tp_size
=
get_moe_tp_size
()
if
output_dim
is
not
None
:
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
...
@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -1121,6 +1138,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
hasattr
(
param
,
"expect_tp_size"
):
if
hasattr
(
param
,
"expect_tp_size"
):
param
.
expect_tp_size
=
self
.
expect_tp_size
param
.
expect_tp_size
=
self
.
expect_tp_size
if
self
.
enable_dp_attn_moe
and
hasattr
(
param
,
"enable_dp_attn_moe"
):
tp_size
=
get_moe_tp_size
()
param
.
enable_dp_attn_moe
=
self
.
enable_dp_attn_moe
if
isinstance
(
param
,
BlockQuantScaleParameter
):
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
Fp8LinearMethod
,
Fp8MoEMethod
)
...
@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1552,6 +1573,7 @@ class RowParallelLinear(LinearBase):
*
,
*
,
return_bias
:
bool
=
True
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
expect_tp_size
:
Optional
[
int
]
=
None
,
enable_dp_attn_moe
:
bool
=
False
,
):
):
# Divide the weight matrix along the first dimension.
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
...
@@ -1560,7 +1582,13 @@ class RowParallelLinear(LinearBase):
if
expect_tp_size
is
not
None
:
if
expect_tp_size
is
not
None
:
self
.
tp_rank
=
0
self
.
tp_rank
=
0
self
.
tp_size
=
1
self
.
tp_size
=
1
self
.
expect_tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
self
.
enable_dp_attn_moe
=
enable_dp_attn_moe
if
enable_dp_attn_moe
:
self
.
tp_rank
=
get_moe_tp_rank
()
self
.
tp_size
=
get_moe_tp_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
self
.
output_partition_sizes
=
[
output_size
]
...
@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
...
@@ -1610,6 +1638,11 @@ class RowParallelLinear(LinearBase):
if
self
.
expect_tp_size
is
not
None
:
if
self
.
expect_tp_size
is
not
None
:
tp_rank
=
0
tp_rank
=
0
tp_size
=
1
tp_size
=
1
if
self
.
enable_dp_attn_moe
:
tp_rank
=
get_moe_tp_rank
()
tp_size
=
get_moe_tp_size
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
...
@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1664,6 +1697,9 @@ class RowParallelLinear(LinearBase):
if
self
.
expect_tp_size
is
not
None
and
hasattr
(
param
,
"expect_tp_size"
):
if
self
.
expect_tp_size
is
not
None
and
hasattr
(
param
,
"expect_tp_size"
):
param
.
expect_tp_size
=
self
.
expect_tp_size
param
.
expect_tp_size
=
self
.
expect_tp_size
if
self
.
enable_dp_attn_moe
is
not
None
and
hasattr
(
param
,
"enable_dp_attn_moe"
):
param
.
enable_dp_attn_moe
=
self
.
enable_dp_attn_moe
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
cda54326
...
@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -59,6 +59,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.layers.dp_attention
import
(
dp_gather
,
dp_reduce_scatter_tensor
,
get_moe_tp_size
,
get_moe_tp_rank
,
get_attention_tp_size
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
...
@@ -85,17 +87,22 @@ class DeepseekV2MLP(nn.Module):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
enable_dp_attention
=
vllm_config
.
parallel_config
.
enable_dp_attention
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
enable_dp_attn_moe
=
enable_dp_attention
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
reduce_results
=
reduce_results
if
not
enable_dp_attention
else
False
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
prefix
=
f
"
{
prefix
}
.down_proj"
,
enable_dp_attn_moe
=
enable_dp_attention
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -979,6 +986,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -979,6 +986,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
config
=
config
self
.
config
=
config
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
enable_dp_attention
=
vllm_config
.
parallel_config
.
enable_dp_attention
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
if
(
config
.
n_routed_experts
is
not
None
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
>=
config
.
first_k_dense_replace
...
@@ -1006,6 +1015,9 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1006,6 +1015,9 @@ class DeepseekV2DecoderLayer(nn.Module):
DeepseekV2MoE
)
and
self
.
use_deepep
and
\
DeepseekV2MoE
)
and
self
.
use_deepep
and
\
self
.
tp_size
>
1
and
not
self
.
is_mtp_layer
:
self
.
tp_size
>
1
and
not
self
.
is_mtp_layer
:
reduce_results
=
False
reduce_results
=
False
else
:
if
self
.
enable_dp_attention
:
reduce_results
=
False
if
model_config
.
use_mla
:
if
model_config
.
use_mla
:
attn_cls
=
DeepseekV2MLAAttention
attn_cls
=
DeepseekV2MLAAttention
...
@@ -1176,6 +1188,10 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1176,6 +1188,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
tensor_model_parallel_reduce_scatter
(
hidden_states
,
dim
=
0
)
hidden_states
=
tensor_model_parallel_reduce_scatter
(
hidden_states
,
dim
=
0
)
if
self
.
enable_dp_attention
:
if
self
.
tp_rank
==
0
:
hidden_states
+=
residual
hidden_states
=
dp_gather
(
hidden_states
)
if
hidden_states
.
dtype
==
torch
.
float16
:
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Fix FP16 overflow
...
@@ -1188,8 +1204,14 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1188,8 +1204,14 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
if
not
self
.
enable_dp_attention
:
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
else
:
num_tokens
=
hidden_states
.
shape
[
0
]
new_bs
=
num_tokens
//
get_moe_tp_size
()
*
get_attention_tp_size
()
residual
=
hidden_states
[
self
.
dp_rank
*
new_bs
:
(
self
.
dp_rank
+
1
)
*
new_bs
,
:]
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
if
self
.
is_mtp_layer
:
if
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
if
isinstance
(
self
.
mlp
,
...
@@ -1201,9 +1223,11 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1201,9 +1223,11 @@ class DeepseekV2DecoderLayer(nn.Module):
new_bs
=
(
ori_bs
+
pad_size
)
//
self
.
tp_size
new_bs
=
(
ori_bs
+
pad_size
)
//
self
.
tp_size
hidden_states
=
hidden_states
[
self
.
tp_rank
*
new_bs
:
(
self
.
tp_rank
+
1
)
*
new_bs
,
:].
contiguous
()
hidden_states
=
hidden_states
[
self
.
tp_rank
*
new_bs
:
(
self
.
tp_rank
+
1
)
*
new_bs
,
:].
contiguous
()
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
self
.
enable_dp_attention
:
hidden_states
=
dp_reduce_scatter_tensor
(
hidden_states
)
if
self
.
is_mtp_layer
:
if
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
...
...
vllm/model_executor/parameter.py
View file @
cda54326
...
@@ -10,6 +10,7 @@ from torch.nn import Parameter
...
@@ -10,6 +10,7 @@ from torch.nn import Parameter
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.utils
import
_make_synced_weight_loader
from
vllm.model_executor.utils
import
_make_synced_weight_loader
from
vllm.model_executor.layers.dp_attention
import
get_moe_tp_rank
,
get_moe_tp_size
__all__
=
[
__all__
=
[
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
...
@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -97,6 +98,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self
.
_output_dim
=
output_dim
self
.
_output_dim
=
output_dim
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
expect_tp_size
=
-
1
self
.
expect_tp_size
=
-
1
self
.
enable_dp_attn_moe
=
False
@
property
@
property
...
@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -107,6 +109,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_rank
=
0
if
self
.
enable_dp_attn_moe
:
tp_rank
=
get_moe_tp_rank
()
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
tp_rank
*
shard_size
,
shard_size
)
tp_rank
*
shard_size
,
shard_size
)
...
@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -129,6 +135,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_rank
=
0
if
self
.
enable_dp_attn_moe
:
tp_rank
=
get_moe_tp_rank
()
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
...
@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
...
@@ -174,6 +184,7 @@ class RowvLLMParameter(BasevLLMParameter):
self
.
_input_dim
=
input_dim
self
.
_input_dim
=
input_dim
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
expect_tp_size
=
-
1
self
.
expect_tp_size
=
-
1
self
.
enable_dp_attn_moe
=
False
@
property
@
property
def
input_dim
(
self
):
def
input_dim
(
self
):
...
@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
...
@@ -183,6 +194,9 @@ class RowvLLMParameter(BasevLLMParameter):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_rank
=
0
if
self
.
enable_dp_attn_moe
:
tp_rank
=
get_moe_tp_rank
()
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
input_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
input_dim
,
tp_rank
*
shard_size
,
shard_size
)
tp_rank
*
shard_size
,
shard_size
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
cda54326
...
@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
...
@@ -12,7 +12,7 @@ from vllm.attention.layer import Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
get_layers_from_vllm_config
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
DPMetadata
,
set_forward_context
from
vllm.forward_context
import
DPMetadata
,
set_forward_context
,
get_warming_up
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
...
@@ -93,6 +93,8 @@ class EagleProposer:
...
@@ -93,6 +93,8 @@ class EagleProposer:
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
self
.
enable_expert_parallel
=
vllm_config
.
parallel_config
.
enable_expert_parallel
self
.
enable_expert_parallel
=
vllm_config
.
parallel_config
.
enable_expert_parallel
self
.
enable_dp_attention
=
vllm_config
.
parallel_config
.
enable_dp_attention
self
.
attn_tp_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
def
propose
(
def
propose
(
self
,
self
,
...
@@ -189,7 +191,8 @@ class EagleProposer:
...
@@ -189,7 +191,8 @@ class EagleProposer:
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
if
self
.
enable_dp_attention
:
num_input_tokens
=
round_up
(
num_input_tokens
,
self
.
attn_tp_size
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
num_input_tokens
+=
num_pad
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
...
@@ -279,6 +282,13 @@ class EagleProposer:
...
@@ -279,6 +282,13 @@ class EagleProposer:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
else
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
# dp attention need all dp rank process same number tokens
if
self
.
enable_dp_attention
:
input_batch_size
=
round_up
(
input_batch_size
,
self
.
attn_tp_size
)
num_pad
,
_
=
self
.
get_dp_padding
(
input_batch_size
)
input_batch_size
+=
num_pad
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
@@ -373,6 +383,7 @@ class EagleProposer:
...
@@ -373,6 +383,7 @@ class EagleProposer:
attn_metadata
.
num_decode_tokens
)
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
attn_metadata
.
decode
.
seq_lens
)
...
@@ -531,10 +542,9 @@ class EagleProposer:
...
@@ -531,10 +542,9 @@ class EagleProposer:
#
#
# TODO(tms) : There are many cases where padding is enabled for
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
!=
'naive'
:
if
not
self
.
enable_dp_attention
and
not
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
:
# auto
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
!=
'naive'
:
if
not
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
:
# Early exit.
# Early exit.
return
0
,
None
return
0
,
None
...
@@ -566,7 +576,7 @@ class EagleProposer:
...
@@ -566,7 +576,7 @@ class EagleProposer:
# Padding for DP
# Padding for DP
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_pad
,
_
=
self
.
get_dp_padding
(
num_tokens
)
num_input_tokens
+=
num_pad
num_input_tokens
+=
num_pad
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
...
@@ -578,17 +588,57 @@ class EagleProposer:
...
@@ -578,17 +588,57 @@ class EagleProposer:
self
.
hidden_states
[:
num_input_tokens
],
self
.
hidden_states
[:
num_input_tokens
],
)
)
if
self
.
dp_size
>
1
and
self
.
enable_expert_parallel
and
self
.
num_speculative_tokens
>
1
:
if
self
.
dp_size
>
1
and
(
self
.
enable_expert_parallel
or
self
.
enable_dp_attention
)
and
self
.
num_speculative_tokens
>
1
:
num_token
=
1
num_tokens
=
1
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
# dp attention need all dp rank process same number tokens
with
set_forward_context
(
attn_metadata
,
if
self
.
enable_dp_attention
:
self
.
vllm_config
,
num_tokens
=
round_up
(
num_tokens
,
self
.
attn_tp_size
)
num_tokens
=
num_tokens
):
num_pad
,
_
=
self
.
get_dp_padding
(
num_tokens
)
self
.
model
(
num_tokens
+=
num_pad
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
if
not
get_warming_up
():
self
.
hidden_states
[:
num_tokens
],
common_attn_metadata
=
CommonAttentionMetadata
(
)
query_start_loc
=
self
.
runner
.
query_start_loc
[:
num_tokens
+
1
],
seq_lens
=
self
.
runner
.
seq_lens
[:
num_tokens
],
num_reqs
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
slot_mapping
=
self
.
runner
.
slot_mapping
[:
num_tokens
],
spec_layer_decoding
=
True
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build_for_cudagraph_capture
(
common_attn_metadata
=
common_attn_metadata
)
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
if
self
.
attn_metadata_cudagraph
is
not
None
:
if
i
==
0
:
attn_metadata_cudagraph
=
self
.
attn_metadata_cudagraph
attn_metadata_cudagraph
.
num_actual_tokens
=
num_tokens
attn_metadata_cudagraph
.
num_decodes
=
num_tokens
attn_metadata_cudagraph
.
num_decode_tokens
=
num_tokens
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
num_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
num_tokens
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
num_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
)
def
validate_same_kv_cache_group
(
self
,
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cda54326
...
@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
...
@@ -31,7 +31,7 @@ from vllm.distributed.parallel_state import (
prepare_communication_buffer_for_model
,
prepare_communication_buffer_for_model
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
,
set_profilling
)
set_forward_context
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -339,6 +339,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
self
.
enable_expert_parallel
and
self
.
dp_size
>
1
and
self
.
tp_size
>
1
:
if
self
.
enable_expert_parallel
and
self
.
dp_size
>
1
and
self
.
tp_size
>
1
:
self
.
ep_sp
=
True
self
.
ep_sp
=
True
self
.
enable_dp_attention
=
self
.
parallel_config
.
enable_dp_attention
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -1275,13 +1277,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1275,13 +1277,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
:
if
not
self
.
enable_dp_attention
and
not
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
:
# auto
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
!=
'naive'
:
if
not
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
:
# Early exit.
# Early exit.
return
0
,
None
return
0
,
None
num_tokens_across_dp
=
DPMetadata
.
num_tokens_across_dp
(
num_tokens_across_dp
=
DPMetadata
.
num_tokens_across_dp
(
num_tokens
,
dp_size
,
dp_rank
)
num_tokens
,
dp_size
,
dp_rank
)
max_tokens_across_dp_cpu
=
torch
.
max
(
num_tokens_across_dp
).
item
()
max_tokens_across_dp_cpu
=
torch
.
max
(
num_tokens_across_dp
).
item
()
...
@@ -1357,7 +1357,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1357,7 +1357,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
if
self
.
ep_sp
or
self
.
enable_dp_attention
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
@@ -1638,9 +1638,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1638,9 +1638,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Mask out the sampled tokens that should not be sampled.
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
valid_sampled_token_ids
[
i
].
clear
()
if
spec_token_ids
is
not
None
:
for
i
in
discard_sampled_tokens_req_indices
:
spec_token_ids
[
i
].
clear
()
# Cache the sampled tokens in the model runner, so that the scheduler
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# doesn't need to send them back.
...
@@ -1681,6 +1678,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1681,6 +1678,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
attn_metadata
,
attn_metadata
,
)
)
if
spec_token_ids
is
not
None
:
for
i
in
discard_sampled_tokens_req_indices
:
spec_token_ids
[
i
].
clear
()
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
...
@@ -2121,13 +2122,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2121,13 +2122,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
skip_eplb
:
bool
=
False
,
skip_eplb
:
bool
=
False
,
is_profile
:
bool
=
False
,
is_profile
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
if
self
.
ep_sp
or
self
.
enable_dp_attention
:
if
num_tokens
<
self
.
tp_size
:
if
num_tokens
<
self
.
tp_size
:
num_tokens
=
self
.
tp_size
num_tokens
=
self
.
tp_size
# Padding for DP
num_tokens_across_dp
=
0
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_tokens
+=
num_pad
num_tokens
+=
num_pad
...
@@ -2148,13 +2148,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2148,13 +2148,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_tokens
//
min_tokens_per_req
num_reqs
=
num_tokens
//
min_tokens_per_req
if
self
.
ep_sp
:
if
self
.
ep_sp
or
self
.
enable_dp_attention
:
num_actual_tokens
=
round_down
(
num_tokens
,
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_actual_tokens
=
round_down
(
num_tokens
,
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_actual_tokens
//
min_tokens_per_req
num_reqs
=
num_actual_tokens
//
min_tokens_per_req
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
if
not
self
.
ep_sp
:
if
not
(
self
.
ep_sp
or
self
.
enable_dp_attention
)
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
else
:
else
:
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
...
@@ -3219,7 +3219,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3219,7 +3219,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
if
self
.
ep_sp
or
self
.
enable_dp_attention
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
...
...
vllm/v1/worker/gpu_worker.py
View file @
cda54326
...
@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -17,6 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
set_custom_all_reduce
)
set_custom_all_reduce
)
from
vllm.distributed.kv_transfer
import
ensure_kv_transfer_initialized
from
vllm.distributed.kv_transfer
import
ensure_kv_transfer_initialized
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.model_executor.layers.dp_attention
import
initialize_dp_attention
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
...
@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
...
@@ -30,6 +31,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
from
vllm.forward_context
import
(
set_warming_up
,
get_warming_up
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -260,6 +262,7 @@ class Worker(WorkerBase):
...
@@ -260,6 +262,7 @@ class Worker(WorkerBase):
# warm up sizes that are not in cudagraph capture sizes,
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
# e.g. for the max-num-batched token size in chunked prefill.
set_warming_up
(
True
)
warmup_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
.
copy
()
warmup_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
.
copy
()
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
warmup_sizes
=
[
warmup_sizes
=
[
...
@@ -297,6 +300,7 @@ class Worker(WorkerBase):
...
@@ -297,6 +300,7 @@ class Worker(WorkerBase):
# Reset the seed to ensure that the random state is not affected by
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
# the model initialization and profiling.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
set_warming_up
(
False
)
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
return
self
.
model_runner
.
get_model
()
...
@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
...
@@ -399,6 +403,9 @@ def init_worker_distributed_environment(
ensure_kv_transfer_initialized
(
vllm_config
)
ensure_kv_transfer_initialized
(
vllm_config
)
if
vllm_config
.
parallel_config
.
enable_dp_attention
:
initialize_dp_attention
(
vllm_config
,
backend
)
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
...
...
vllm/zero_overhead/v1/eagle.py
View file @
cda54326
...
@@ -112,6 +112,8 @@ class V1ZeroEagleProposer(EagleProposer):
...
@@ -112,6 +112,8 @@ class V1ZeroEagleProposer(EagleProposer):
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
if
self
.
enable_dp_attention
:
num_input_tokens
=
round_up
(
num_input_tokens
,
self
.
attn_tp_size
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
num_input_tokens
+=
num_pad
...
@@ -202,6 +204,13 @@ class V1ZeroEagleProposer(EagleProposer):
...
@@ -202,6 +204,13 @@ class V1ZeroEagleProposer(EagleProposer):
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
else
:
input_batch_size
=
batch_size
input_batch_size
=
batch_size
# dp attention need all dp rank process same number tokens
if
self
.
enable_dp_attention
:
input_batch_size
=
round_up
(
input_batch_size
,
self
.
attn_tp_size
)
num_pad
,
_
=
self
.
get_dp_padding
(
input_batch_size
)
input_batch_size
+=
num_pad
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment