Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
72c62eae
Unverified
Commit
72c62eae
authored
Mar 05, 2025
by
Tyler Michael Smith
Committed by
GitHub
Mar 04, 2025
Browse files
[V1] EP/TP MoE + DP Attention (#13931)
parent
0a995d54
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
250 additions
and
75 deletions
+250
-75
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+10
-7
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+1
-0
vllm/attention/layer.py
vllm/attention/layer.py
+2
-2
vllm/compilation/backends.py
vllm/compilation/backends.py
+3
-2
vllm/forward_context.py
vllm/forward_context.py
+15
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+155
-32
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+11
-7
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+6
-2
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+15
-6
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+2
-0
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+3
-1
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+4
-1
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+5
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+9
-0
vllm/utils.py
vllm/utils.py
+2
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+0
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-3
No files found.
examples/offline_inference/data_parallel.py
View file @
72c62eae
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# usage:
# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
# python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
# ranks. And each rank will create a vLLM instance to process its own prompts.
import
os
import
os
...
@@ -7,6 +9,9 @@ import os
...
@@ -7,6 +9,9 @@ import os
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
GPUs_per_dp_rank
=
2
DP_size
=
2
def
main
(
dp_size
,
dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
def
main
(
dp_size
,
dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
dp_rank
)
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
dp_rank
)
...
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
...
@@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
max_tokens
=
16
*
(
dp_rank
+
1
))
max_tokens
=
16
*
(
dp_rank
+
1
))
# Create an LLM.
# Create an LLM.
llm
=
LLM
(
model
=
"
facebook/opt-125m
"
,
llm
=
LLM
(
model
=
"
ibm-research/PowerMoE-3b
"
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
GPUs_per_dp_rank
,
enforce_eager
=
True
)
enforce_eager
=
True
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
# Print the outputs.
...
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
...
@@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
from
multiprocessing
import
Process
from
multiprocessing
import
Process
dp_size
=
2
GPUs_per_dp_rank
=
2
dp_master_ip
=
"127.0.0.1"
dp_master_ip
=
"127.0.0.1"
dp_master_port
=
get_open_port
()
dp_master_port
=
get_open_port
()
procs
=
[]
procs
=
[]
for
i
in
range
(
dp
_size
):
for
i
in
range
(
DP
_size
):
proc
=
Process
(
target
=
main
,
proc
=
Process
(
target
=
main
,
args
=
(
dp
_size
,
i
,
dp_master_ip
,
dp_master_port
,
args
=
(
DP
_size
,
i
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
))
GPUs_per_dp_rank
))
proc
.
start
()
proc
.
start
()
procs
.
append
(
proc
)
procs
.
append
(
proc
)
...
...
tests/kernels/test_moe.py
View file @
72c62eae
...
@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
params_dtype
=
dtype
,
tp_size
=
1
,
tp_size
=
1
,
dp_size
=
1
,
).
cuda
()
).
cuda
()
# Load the weights
# Load the weights
...
...
vllm/attention/layer.py
View file @
72c62eae
...
@@ -324,7 +324,7 @@ def unified_attention(
...
@@ -324,7 +324,7 @@ def unified_attention(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn
_layers
[
layer_name
]
self
=
forward_context
.
no_compile
_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
...
@@ -356,7 +356,7 @@ def unified_attention_with_output(
...
@@ -356,7 +356,7 @@ def unified_attention_with_output(
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn
_layers
[
layer_name
]
self
=
forward_context
.
no_compile
_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
self
.
impl
.
forward
(
self
,
query
,
query
,
...
...
vllm/compilation/backends.py
View file @
72c62eae
...
@@ -396,8 +396,9 @@ class VllmBackend:
...
@@ -396,8 +396,9 @@ class VllmBackend:
cache_dir
=
self
.
compilation_config
.
cache_dir
cache_dir
=
self
.
compilation_config
.
cache_dir
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
os
.
makedirs
(
cache_dir
,
exist_ok
=
True
)
local_cache_dir
=
os
.
path
.
join
(
rank
=
vllm_config
.
parallel_config
.
rank
cache_dir
,
f
"rank_
{
vllm_config
.
parallel_config
.
rank
}
"
)
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
local_cache_dir
=
os
.
path
.
join
(
cache_dir
,
f
"rank_
{
rank
}
_
{
dp_rank
}
"
)
self
.
compilation_config
.
local_cache_dir
=
local_cache_dir
self
.
compilation_config
.
local_cache_dir
=
local_cache_dir
disable_cache
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
disable_cache
=
envs
.
VLLM_DISABLE_COMPILE_CACHE
...
...
vllm/forward_context.py
View file @
72c62eae
...
@@ -25,16 +25,22 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
...
@@ -25,16 +25,22 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
batchsize_forward_time
:
defaultdict
=
defaultdict
(
list
)
@
dataclass
class
DPMetadata
:
num_tokens_across_dp
:
list
[
int
]
cu_tokens_across_dp_cpu
:
torch
.
Tensor
@
dataclass
@
dataclass
class
ForwardContext
:
class
ForwardContext
:
# copy from vllm_config.compilation_config.static_forward_context
# copy from vllm_config.compilation_config.static_forward_context
attn
_layers
:
dict
[
str
,
Any
]
no_compile
_layers
:
dict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
# TODO: extend to support per-layer dynamic forward context
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
virtual_engine
:
int
# set dynamically for each forward pass
num_tokens_across_dp
:
Optional
[
# set dynamically for each forward pass
list
[
int
]]
=
None
# set dynamically for each forward pass
dp_metadata
:
Optional
[
DPMetadata
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
...
@@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any,
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
if
need_to_track_batchsize
:
if
need_to_track_batchsize
:
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
num_tokens_across_dp
=
None
dp_metadata
:
Optional
[
DPMetadata
]
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
...
@@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any,
...
@@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
num_tokens_across_dp
=
num_tokens_tensor
.
tolist
()
cu_tokens_across_dp_cpu
=
torch
.
cumsum
(
num_tokens_tensor
,
dim
=
0
)
dp_metadata
=
DPMetadata
(
num_tokens_across_dp
,
cu_tokens_across_dp_cpu
)
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
_forward_context
=
ForwardContext
(
attn_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
virtual_engine
=
virtual_engine
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
num_tokens_across_dp
=
num_tokens_across_dp
)
dp_metadata
=
dp_metadata
)
try
:
try
:
yield
yield
finally
:
finally
:
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
72c62eae
...
@@ -8,9 +8,11 @@ import torch
...
@@ -8,9 +8,11 @@ import torch
from
torch.nn.parameter
import
UninitializedParameter
from
torch.nn.parameter
import
UninitializedParameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
get_dp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
from
.fused_moe
import
fused_experts
...
@@ -246,6 +249,51 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -246,6 +249,51 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
forward_native
=
forward_cuda
forward_native
=
forward_cuda
def
determine_expert_map
(
ep_size
:
int
,
ep_rank
:
int
,
global_num_experts
:
int
)
->
Tuple
[
int
,
Optional
[
torch
.
Tensor
]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
distributed evenly across ranks. Any remaining are assigned to the
last rank.
Args:
ep_size (int): The size of the expert parallel group
global_num_experts (int): The total number of experts in the model.
Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
(global_num_experts,) mapping from global to local index.
Contains -1 for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
assert
ep_size
>
0
if
ep_size
==
1
:
return
(
global_num_experts
,
None
)
local_num_experts
=
global_num_experts
//
ep_size
# Create a tensor of size num_experts filled with -1
expert_map
=
torch
.
full
((
global_num_experts
,
),
-
1
,
dtype
=
torch
.
int32
)
# Create a expert map for the local experts
if
ep_rank
<
(
ep_size
-
1
):
# Each non-last rank gets local_num_experts experts.
expert_map
[
ep_rank
*
local_num_experts
:
(
ep_rank
+
1
)
*
local_num_experts
]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
else
:
# All remaining experts are assigned to the last rank.
local_num_experts
=
(
global_num_experts
-
ep_rank
*
local_num_experts
)
expert_map
[
-
local_num_experts
:]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
return
(
local_num_experts
,
expert_map
)
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
...
@@ -282,6 +330,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -282,6 +330,7 @@ class FusedMoE(torch.nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
ep_size
:
Optional
[
int
]
=
None
,
ep_size
:
Optional
[
int
]
=
None
,
dp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
...
@@ -293,16 +342,48 @@ class FusedMoE(torch.nn.Module):
...
@@ -293,16 +342,48 @@ class FusedMoE(torch.nn.Module):
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
# For smuggling this layer into the fused moe custom op
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
use_direct_call
=
not
envs
.
VLLM_TEST_ENABLE_EP
# Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing)
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
get_tensor_model_parallel_world_size
())
get_tensor_model_parallel_world_size
())
tp_rank
=
0
if
self
.
tp_size
==
1
else
get_tensor_model_parallel_rank
()
self
.
dp_size
=
(
dp_size
if
dp_size
is
not
None
else
get_dp_group
().
world_size
)
self
.
dp_rank
=
(
0
if
self
.
dp_size
==
1
else
get_dp_group
().
rank_in_group
)
self
.
global_num_experts
=
num_experts
if
envs
.
VLLM_TEST_ENABLE_EP
:
if
envs
.
VLLM_TEST_ENABLE_EP
:
self
.
ep_size
=
self
.
tp_size
# Set TP size to 1 to adjust for EP and adjust EP size and rank
# for DP attention.
self
.
ep_rank
=
tp_rank
+
self
.
tp_size
*
self
.
dp_rank
self
.
tp_rank
=
0
self
.
ep_size
=
self
.
tp_size
*
self
.
dp_size
self
.
tp_size
=
1
self
.
tp_size
=
1
self
.
local_num_experts
,
self
.
expert_map
=
determine_expert_map
(
ep_size
=
self
.
ep_size
,
ep_rank
=
self
.
ep_rank
,
global_num_experts
=
self
.
global_num_experts
)
else
:
else
:
# Adjust TP size for DP attention
self
.
tp_rank
=
tp_rank
+
self
.
tp_size
*
self
.
dp_rank
self
.
ep_rank
=
0
self
.
tp_size
=
self
.
tp_size
*
self
.
dp_size
self
.
ep_size
=
1
self
.
ep_size
=
1
self
.
local_num_experts
=
self
.
global_num_experts
self
.
expert_map
=
None
self
.
top_k
=
top_k
self
.
top_k
=
top_k
self
.
global_num_experts
=
num_experts
self
.
global_num_experts
=
num_experts
self
.
local_num_experts
=
self
.
global_num_experts
//
self
.
ep_size
assert
intermediate_size
%
self
.
tp_size
==
0
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
...
@@ -316,26 +397,6 @@ class FusedMoE(torch.nn.Module):
...
@@ -316,26 +397,6 @@ class FusedMoE(torch.nn.Module):
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
activation
=
activation
self
.
activation
=
activation
self
.
expert_map
=
None
if
self
.
ep_size
>
1
:
# Create a tensor of size num_experts filled with -1
self
.
expert_map
=
torch
.
full
((
self
.
global_num_experts
,
),
-
1
,
dtype
=
torch
.
int32
)
# Create a expert map for the local experts
ep_rank
=
get_tensor_model_parallel_rank
()
if
ep_rank
<
(
self
.
ep_size
-
1
):
# Each non-last rank gets local_num_experts experts.
self
.
expert_map
[
ep_rank
*
self
.
local_num_experts
:
(
ep_rank
+
1
)
*
self
.
local_num_experts
]
=
\
torch
.
arange
(
0
,
self
.
local_num_experts
,
dtype
=
torch
.
int32
)
else
:
# All remaining experts are assigned to the last rank.
self
.
local_num_experts
=
(
self
.
global_num_experts
-
ep_rank
*
self
.
local_num_experts
)
self
.
expert_map
[
-
self
.
local_num_experts
:]
=
\
torch
.
arange
(
0
,
self
.
local_num_experts
,
dtype
=
torch
.
int32
)
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
...
@@ -493,9 +554,6 @@ class FusedMoE(torch.nn.Module):
...
@@ -493,9 +554,6 @@ class FusedMoE(torch.nn.Module):
if
expert_id
==
-
1
:
if
expert_id
==
-
1
:
return
return
# TP rank is set to 0 if EP is enabled
tp_rank
=
0
if
self
.
ep_size
>
1
else
get_tensor_model_parallel_rank
()
# compressed-tensors checkpoints with packed weights are stored flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
# against known CompressionFormat enum values that have this quality
...
@@ -539,8 +597,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -539,8 +597,7 @@ class FusedMoE(torch.nn.Module):
final_shape
=
list
(
loaded_weight
.
shape
)
final_shape
=
list
(
loaded_weight
.
shape
)
if
shard_id
in
[
"w1"
,
"w3"
]:
if
shard_id
in
[
"w1"
,
"w3"
]:
final_shape
[
1
]
*=
2
final_shape
[
1
]
*=
2
final_shape
[
shard_dim
]
=
final_shape
[
final_shape
[
shard_dim
]
=
final_shape
[
shard_dim
]
//
self
.
tp_size
shard_dim
]
//
get_tensor_model_parallel_world_size
()
param
.
materialize
(
final_shape
,
dtype
=
loaded_weight
.
dtype
)
param
.
materialize
(
final_shape
,
dtype
=
loaded_weight
.
dtype
)
expert_data
=
param
.
data
if
full_load
else
param
.
data
[
expert_id
]
expert_data
=
param
.
data
if
full_load
else
param
.
data
[
expert_id
]
...
@@ -567,7 +624,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -567,7 +624,7 @@ class FusedMoE(torch.nn.Module):
shard_id
=
shard_id
,
shard_id
=
shard_id
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
tp_rank
=
self
.
tp_rank
)
return
return
# Case weight scales and zero_points
# Case weight scales and zero_points
...
@@ -584,7 +641,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -584,7 +641,7 @@ class FusedMoE(torch.nn.Module):
shard_dim
=
shard_dim
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
tp_rank
=
self
.
tp_rank
)
elif
quant_method
in
[
elif
quant_method
in
[
FusedMoeWeightScaleSupported
.
GROUP
.
value
,
FusedMoeWeightScaleSupported
.
GROUP
.
value
,
FusedMoeWeightScaleSupported
.
BLOCK
.
value
,
FusedMoeWeightScaleSupported
.
BLOCK
.
value
,
...
@@ -594,7 +651,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -594,7 +651,7 @@ class FusedMoE(torch.nn.Module):
shard_dim
=
shard_dim
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
,
tp_rank
=
self
.
tp_rank
,
load_full_w2
=
getattr
(
param
,
"load_full_w2"
,
False
))
load_full_w2
=
getattr
(
param
,
"load_full_w2"
,
False
))
elif
quant_method
==
FusedMoeWeightScaleSupported
.
TENSOR
.
value
:
elif
quant_method
==
FusedMoeWeightScaleSupported
.
TENSOR
.
value
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
...
@@ -621,7 +678,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -621,7 +678,7 @@ class FusedMoE(torch.nn.Module):
shard_dim
=
shard_dim
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
tp_rank
=
self
.
tp_rank
)
return
return
@
staticmethod
@
staticmethod
...
@@ -665,10 +722,45 @@ class FusedMoE(torch.nn.Module):
...
@@ -665,10 +722,45 @@ class FusedMoE(torch.nn.Module):
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
naive_multicast
(
self
,
x
:
torch
.
Tensor
,
cu_tokens_across_dp_cpu
:
torch
.
Tensor
):
assert
(
len
(
x
.
shape
)
==
2
)
buffer
=
torch
.
empty
((
cu_tokens_across_dp_cpu
[
-
1
],
x
.
size
(
1
)),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
buffer
[
start
:
end
,
:].
copy_
(
x
)
for
idx
in
range
(
get_dp_group
().
world_size
):
start
=
0
if
idx
==
0
else
cu_tokens_across_dp_cpu
[
idx
-
1
]
end
=
cu_tokens_across_dp_cpu
[
idx
]
get_dp_group
().
broadcast
(
buffer
[
start
:
end
,
:],
idx
)
return
buffer
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
):
if
self
.
use_direct_call
:
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
if
self
.
dp_size
>
1
:
cu_tokens_across_dp_cpu
=
get_forward_context
(
).
dp_metadata
.
cu_tokens_across_dp_cpu
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_dp_cpu
)
router_logits
=
self
.
naive_multicast
(
router_logits
,
cu_tokens_across_dp_cpu
)
# Matrix multiply.
# Matrix multiply.
final_hidden_states
=
self
.
quant_method
.
apply
(
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
self
,
layer
=
self
,
...
@@ -687,6 +779,14 @@ class FusedMoE(torch.nn.Module):
...
@@ -687,6 +779,14 @@ class FusedMoE(torch.nn.Module):
activation
=
self
.
activation
,
activation
=
self
.
activation
,
)
)
if
self
.
dp_size
>
1
:
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
all_hidden_states
=
get_dp_group
().
all_reduce
(
final_hidden_states
)
final_hidden_states
=
all_hidden_states
[
start
:
end
,
:]
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
# Default set to False. (May have to add shared expert outputs.)
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
@@ -757,3 +857,26 @@ class FusedMoE(torch.nn.Module):
...
@@ -757,3 +857,26 @@ class FusedMoE(torch.nn.Module):
s
+=
f
", scoring_func='
{
self
.
scoring_func
}
', activation='
{
self
.
activation
}
'"
# noqa: E501
s
+=
f
", scoring_func='
{
self
.
scoring_func
}
', activation='
{
self
.
activation
}
'"
# noqa: E501
return
s
return
s
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"moe_forward"
,
op_func
=
moe_forward
,
mutates_args
=
[],
fake_impl
=
moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/model_executor/models/aria.py
View file @
72c62eae
...
@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
...
@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
pixel_values
:
torch
.
Tensor
pixel_values
:
torch
.
Tensor
pixel_mask
:
Optional
[
torch
.
Tensor
]
pixel_mask
:
Optional
[
torch
.
Tensor
]
"""
"""
Shape:
Shape:
pixel_values: `(batch_size * num_images, num_channels, height, width)`
pixel_values: `(batch_size * num_images, num_channels, height, width)`
pixel_mask: `(batch_size * num_images, height, width)`
pixel_mask: `(batch_size * num_images, height, width)`
"""
"""
...
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
...
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
query numbers,
query numbers,
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
based on image resolution.
based on image resolution.
embed_dim (int): Embedding dimension.
embed_dim (int): Embedding dimension.
num_heads (int): Number of attention heads.
num_heads (int): Number of attention heads.
kv_dim (int): Dimension of key and value.
kv_dim (int): Dimension of key and value.
ff_dim (int): Hidden dimension of the feed-forward network.
ff_dim (int): Hidden dimension of the feed-forward network.
output_dim (int): Output dimension.
output_dim (int): Output dimension.
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
Outputs:
Outputs:
...
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
...
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
self
,
self
,
config
:
AriaTextConfig
,
config
:
AriaTextConfig
,
quant_config
:
Optional
[
QuantizationConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
...
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
reduce_results
=
True
,
reduce_results
=
True
,
prefix
=
f
"
{
prefix
}
.experts"
,
)
)
self
.
shared_experts
=
LlamaMLP
(
self
.
shared_experts
=
LlamaMLP
(
config
.
hidden_size
,
config
.
hidden_size
,
...
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
...
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
(
config
,
cache_config
,
quant_config
,
prefix
)
super
().
__init__
(
config
,
cache_config
,
quant_config
,
prefix
)
self
.
mlp
=
AriaTextMoELayer
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
AriaTextMoELayer
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
class
AriaTextModel
(
LlamaModel
,
SupportsQuant
):
class
AriaTextModel
(
LlamaModel
,
SupportsQuant
):
...
...
vllm/model_executor/models/dbrx.py
View file @
72c62eae
...
@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
...
@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
config
:
DbrxConfig
,
config
:
DbrxConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
(
super
().
__init__
(
num_experts
=
config
.
ffn_config
.
moe_num_experts
,
num_experts
=
config
.
ffn_config
.
moe_num_experts
,
...
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
...
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
get_tensor_model_parallel_world_size
(),
tp_size
=
get_tensor_model_parallel_world_size
(),
prefix
=
prefix
,
)
)
self
.
config
=
config
self
.
config
=
config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
...
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
config
:
DbrxConfig
,
config
:
DbrxConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
...
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):
...
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):
self
.
experts
=
DbrxExperts
(
config
=
config
,
self
.
experts
=
DbrxExperts
(
config
=
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
params_dtype
=
self
.
params_dtype
)
params_dtype
=
self
.
params_dtype
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_shape
=
hidden_states
.
shape
orig_shape
=
hidden_states
.
shape
...
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
...
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
cache_config
,
cache_config
,
quant_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.norm_attn_norm"
)
prefix
=
f
"
{
prefix
}
.norm_attn_norm"
)
self
.
ffn
=
DbrxMoE
(
config
,
quant_config
)
self
.
ffn
=
DbrxMoE
(
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.ffn"
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/jamba.py
View file @
72c62eae
...
@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
...
@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
top_k
:
Optional
[
int
]
=
None
,
top_k
:
Optional
[
int
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
num_total_experts
=
num_experts
or
config
.
num_experts
self
.
num_total_experts
=
num_experts
or
config
.
num_experts
self
.
top_k
=
top_k
or
config
.
num_experts_per_tok
self
.
top_k
=
top_k
or
config
.
num_experts_per_tok
...
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
...
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
reduce_results
=
True
,
reduce_results
=
True
,
renormalize
=
False
,
renormalize
=
False
,
use_grouped_topk
=
False
,
use_grouped_topk
=
False
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
orig_shape
=
hidden_states
.
shape
orig_shape
=
hidden_states
.
shape
...
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
...
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
config
:
JambaConfig
,
config
:
JambaConfig
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
config
,
super
().
__init__
(
config
,
num_experts
=
1
,
num_experts
=
1
,
top_k
=
1
,
top_k
=
1
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
tp_size
=
tp_size
,
tp_size
=
tp_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
class
JambaMambaDecoderLayer
(
nn
.
Module
):
class
JambaMambaDecoderLayer
(
nn
.
Module
):
...
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
is_lora_enabled
:
Optional
[
bool
]
=
False
,
is_lora_enabled
:
Optional
[
bool
]
=
False
,
prefix
:
str
=
""
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):
...
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
self
.
feed_forward
=
ffn_layer_class
(
config
,
quant_config
=
quant_config
)
self
.
feed_forward
=
ffn_layer_class
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):
...
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
num_experts
=
config
.
layers_num_experts
[
layer_idx
]
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
ffn_layer_class
=
JambaMoE
if
num_experts
>
1
else
JambaMLP
self
.
feed_forward
=
ffn_layer_class
(
config
,
quant_config
=
quant_config
)
self
.
feed_forward
=
ffn_layer_class
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.feed_forward"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
pre_ff_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
...
vllm/model_executor/models/mixtral.py
View file @
72c62eae
...
@@ -71,6 +71,7 @@ class MixtralMoE(nn.Module):
...
@@ -71,6 +71,7 @@ class MixtralMoE(nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
dp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -93,6 +94,7 @@ class MixtralMoE(nn.Module):
...
@@ -93,6 +94,7 @@ class MixtralMoE(nn.Module):
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
,
tp_size
=
tp_size
,
dp_size
=
dp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/models/olmoe.py
View file @
72c62eae
...
@@ -80,7 +80,8 @@ class OlmoeMoE(nn.Module):
...
@@ -80,7 +80,8 @@ class OlmoeMoE(nn.Module):
reduce_results
=
True
,
reduce_results
=
True
,
renormalize
=
False
,
renormalize
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
)
tp_size
=
tp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -212,6 +213,7 @@ class OlmoeDecoderLayer(nn.Module):
...
@@ -212,6 +213,7 @@ class OlmoeDecoderLayer(nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
...
...
vllm/model_executor/models/phimoe.py
View file @
72c62eae
...
@@ -249,6 +249,7 @@ class PhiMoE(nn.Module):
...
@@ -249,6 +249,7 @@ class PhiMoE(nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -272,7 +273,8 @@ class PhiMoE(nn.Module):
...
@@ -272,7 +273,8 @@ class PhiMoE(nn.Module):
renormalize
=
False
,
renormalize
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
,
tp_size
=
tp_size
,
custom_routing_function
=
phimoe_routing_function
)
custom_routing_function
=
phimoe_routing_function
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -396,6 +398,7 @@ class PhiMoEDecoderLayer(nn.Module):
...
@@ -396,6 +398,7 @@ class PhiMoEDecoderLayer(nn.Module):
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
,
)
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
eps
=
config
.
rms_norm_eps
,
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
72c62eae
...
@@ -100,6 +100,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -100,6 +100,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -115,7 +116,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -115,7 +116,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_experts
,
config
.
num_experts
,
...
@@ -277,7 +279,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
...
@@ -277,7 +279,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
config
.
num_experts
>
0
and
config
.
num_experts
>
0
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen2MoeSparseMoeBlock
(
config
=
config
,
self
.
mlp
=
Qwen2MoeSparseMoeBlock
(
config
=
config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
else
:
self
.
mlp
=
Qwen2MoeMLP
(
self
.
mlp
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
...
...
vllm/platforms/cuda.py
View file @
72c62eae
...
@@ -111,6 +111,7 @@ class CudaPlatformBase(Platform):
...
@@ -111,6 +111,7 @@ class CudaPlatformBase(Platform):
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
scheduler_config
=
vllm_config
.
scheduler_config
compilation_config
=
vllm_config
.
compilation_config
if
parallel_config
.
worker_cls
==
"auto"
:
if
parallel_config
.
worker_cls
==
"auto"
:
if
scheduler_config
.
is_multi_step
:
if
scheduler_config
.
is_multi_step
:
...
@@ -150,6 +151,14 @@ class CudaPlatformBase(Platform):
...
@@ -150,6 +151,14 @@ class CudaPlatformBase(Platform):
"FlashMLA: Forcing kv cache block size to 64 since this"
"FlashMLA: Forcing kv cache block size to 64 since this"
" is currently the only block size supported by the kernel."
)
" is currently the only block size supported by the kernel."
)
if
(
parallel_config
.
data_parallel_size
>
1
and
compilation_config
.
use_cudagraph
):
logger
.
info
(
"Data Parallel: Forcing enforce eager to be True since DP is "
"currently not supported with CUDA Graphs."
)
vllm_config
.
model_config
.
enforce_eager
=
True
compilation_config
.
use_cudagraph
=
False
@
classmethod
@
classmethod
def
get_current_memory_usage
(
cls
,
def
get_current_memory_usage
(
cls
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
device
:
Optional
[
torch
.
types
.
Device
]
=
None
...
...
vllm/utils.py
View file @
72c62eae
...
@@ -2196,8 +2196,8 @@ def bind_kv_cache(
...
@@ -2196,8 +2196,8 @@ def bind_kv_cache(
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
layer_need_kv_cache
=
[
layer_need_kv_cache
=
[
layer_name
for
layer_name
in
ctx
layer_name
for
layer_name
in
ctx
if
ctx
[
layer_name
]
.
attn_type
in
(
AttentionType
.
DECODER
,
if
(
hasattr
(
ctx
[
layer_name
]
,
'
attn_type
'
)
and
ctx
[
layer_name
].
attn_type
AttentionType
.
ENCODER_DECODER
)
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
)
)
]
]
layer_index_sorted
=
sorted
(
layer_index_sorted
=
sorted
(
set
(
set
(
...
...
vllm/v1/engine/core.py
View file @
72c62eae
...
@@ -149,7 +149,6 @@ class EngineCore:
...
@@ -149,7 +149,6 @@ class EngineCore:
if
not
self
.
scheduler
.
has_unfinished_requests
():
if
not
self
.
scheduler
.
has_unfinished_requests
():
return
EngineCoreOutputs
(
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
())
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
72c62eae
...
@@ -17,6 +17,7 @@ from vllm.distributed.parallel_state import get_pp_group, graph_capture
...
@@ -17,6 +17,7 @@ from vllm.distributed.parallel_state import get_pp_group, graph_capture
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
...
@@ -1357,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1357,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"""
"""
Initialize KV cache based on `kv_cache_config`.
Initialize KV cache based on `kv_cache_config`.
Args:
Args:
kv_cache_config: Configuration for the KV cache, including the KV
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
cache size of each layer
"""
"""
if
len
(
kv_cache_config
.
groups
)
>
1
:
if
len
(
kv_cache_config
.
groups
)
>
1
:
...
@@ -1389,10 +1390,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1389,10 +1390,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
def
get_kv_cache_spec
(
self
)
->
KVCacheSpec
:
"""
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Attention module in the static forward context.
Returns:
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
format. Layers that do not need KV cache are not included.
"""
"""
...
@@ -1400,6 +1401,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1400,6 +1401,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
block_size
=
self
.
vllm_config
.
cache_config
.
block_size
kv_cache_spec
:
KVCacheSpec
=
{}
kv_cache_spec
:
KVCacheSpec
=
{}
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
for
layer_name
,
attn_module
in
forward_ctx
.
items
():
if
isinstance
(
attn_module
,
FusedMoE
):
continue
# TODO: Support other attention modules, e.g., sliding window,
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
# cross-attention, MLA.
assert
isinstance
(
attn_module
,
Attention
)
assert
isinstance
(
attn_module
,
Attention
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment