Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9fd658a
Unverified
Commit
e9fd658a
authored
Jun 26, 2025
by
Bowen Wang
Committed by
GitHub
Jun 26, 2025
Browse files
[Feature] Expert Parallelism Load Balancer (EPLB) (#18343)
Signed-off-by:
Bowen Wang
<
abmfy@icloud.com
>
parent
07b8fae2
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2201 additions
and
30 deletions
+2201
-30
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+17
-0
tests/distributed/test_eplb_algo.py
tests/distributed/test_eplb_algo.py
+292
-0
tests/distributed/test_eplb_execute.py
tests/distributed/test_eplb_execute.py
+504
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+10
-2
vllm/config.py
vllm/config.py
+33
-0
vllm/distributed/eplb/__init__.py
vllm/distributed/eplb/__init__.py
+7
-0
vllm/distributed/eplb/eplb_state.py
vllm/distributed/eplb/eplb_state.py
+431
-0
vllm/distributed/eplb/rebalance_algo.py
vllm/distributed/eplb/rebalance_algo.py
+233
-0
vllm/distributed/eplb/rebalance_execute.py
vllm/distributed/eplb/rebalance_execute.py
+306
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+20
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+236
-28
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+8
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+42
-0
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+8
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+14
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+8
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+8
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+8
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+8
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+8
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
e9fd658a
...
@@ -168,6 +168,23 @@ steps:
...
@@ -168,6 +168,23 @@ steps:
-
VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
popd
-
popd
-
label
:
EPLB Algorithm Test
working_dir
:
"
/vllm-workspace/tests"
source_file_dependencies
:
-
vllm/distributed/eplb
-
tests/distributed/test_eplb_algo.py
commands
:
-
pytest -v -s distributed/test_eplb_algo.py
-
label
:
EPLB Execution Test
# 5min
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
4
source_file_dependencies
:
-
vllm/distributed/eplb
-
tests/distributed/test_eplb_execute.py
commands
:
-
pytest -v -s distributed/test_eplb_execute.py
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
mirror_hardwares
:
[
amdexperimental
,
amdproduction
]
num_gpus
:
2
num_gpus
:
2
...
...
tests/distributed/test_eplb_algo.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.distributed.eplb.rebalance_algo
import
rebalance_experts
def test_basic_rebalance():
    """Check shapes, value ranges and exact outputs on the reference example."""
    # Example from https://github.com/deepseek-ai/eplb
    per_layer_load = [
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ]
    weight = torch.tensor(per_layer_load)

    num_layers = weight.shape[0]
    num_replicas, num_groups, num_nodes, num_gpus = 16, 4, 2, 8

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Output shapes: 2 layers, 16 physical slots, 12 logical experts.
    assert phy2log.shape == (2, 16), (
        f"Expected `phy2log` shape (2, 16), got {phy2log.shape}")
    assert log2phy.shape[0] == 2, (
        f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}")
    assert log2phy.shape[1] == 12, (
        f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}")
    assert logcnt.shape == (2, 12), (
        f"Expected `logcnt` shape (2, 12), got {logcnt.shape}")

    # Every physical slot must point at a valid logical expert.
    ids_non_negative = torch.all(phy2log >= 0)
    ids_below_limit = torch.all(phy2log < 12)
    assert ids_non_negative and ids_below_limit, (
        "Physical to logical mapping should be in range [0, 12)")

    # Replica counts: at least one per expert, and they add up to all slots.
    assert torch.all(logcnt >= 1), (
        "Each logical expert should have at least 1 replica")
    replica_total = logcnt.sum(dim=1).sum()
    assert replica_total == num_replicas * num_layers, (
        f"Total replicas should be {num_replicas * num_layers}")

    # Exact expected output for the reference example.
    expected_phy2log = torch.tensor([
        [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
        [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1],
    ])
    assert torch.all(phy2log == expected_phy2log)

    expected_logcnt = torch.tensor([
        [1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1],
        [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1],
    ])
    assert torch.all(logcnt == expected_logcnt)
def test_single_gpu_case():
    """Rebalance four experts with one group, one node and one GPU."""
    weight = torch.tensor([[10, 20, 30, 40]])
    num_replicas, num_groups, num_nodes, num_gpus = 4, 1, 1, 1

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Shapes: a single layer with four slots / four logical experts.
    assert phy2log.shape == (1, 4)
    assert log2phy.shape[0] == 1
    assert log2phy.shape[1] == 4
    assert logcnt.shape == (1, 4)

    # Each logical expert must appear in the physical layout.
    mapped_experts = set(phy2log[0].tolist())
    assert mapped_experts == {0, 1, 2, 3}
def test_equal_weights():
    """Eight equally loaded experts on eight slots get one replica each."""
    weight = torch.tensor([[50] * 8])
    num_replicas, num_groups, num_nodes, num_gpus = 8, 2, 2, 4

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Shapes
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 8)

    # As many slots as experts, so nothing can be replicated.
    every_expert_single = torch.all(logcnt == 1)
    assert every_expert_single, (
        "With equal weights and no replication, "
        "each expert should have exactly 1 replica")
def test_extreme_weight_imbalance():
    """One expert carries almost all the load and should get more replicas."""
    weight = torch.tensor([[1000] + [1] * 7])
    num_replicas, num_groups, num_nodes, num_gpus = 12, 2, 2, 4

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Shapes
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)

    # Expert 0 (load 1000) versus expert 1 (load 1).
    hot_count = logcnt[0, 0]
    cold_count = logcnt[0, 1]
    assert hot_count > cold_count, (
        "Expert with highest weight should have more replicas")
def test_multiple_layers():
    """Rebalance three layers with different load patterns in one call."""
    ascending = [10, 20, 30, 40, 50, 60]  # First layer
    descending = [60, 50, 40, 30, 20, 10]  # Second layer (opposite pattern)
    uniform = [25, 25, 25, 25, 25, 25]  # Third layer (equal weights)
    weight = torch.tensor([ascending, descending, uniform])

    num_replicas, num_groups, num_nodes, num_gpus = 8, 2, 2, 4

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Shapes
    assert phy2log.shape == (3, 8)
    assert logcnt.shape == (3, 6)

    # Per-layer sanity: valid expert IDs and the right number of replicas.
    for layer in range(3):
        layer_map = phy2log[layer]
        ids_non_negative = torch.all(layer_map >= 0)
        ids_below_limit = torch.all(layer_map < 6)
        assert ids_non_negative and ids_below_limit, (
            f"Layer {layer} physical to logical mapping"
            "should be in range [0, 6)")

        layer_replicas = torch.sum(logcnt[layer])
        assert layer_replicas == num_replicas, (
            f"Layer {layer} total replicas should be {num_replicas}")
def test_parameter_validation():
    """Exercise one accepted and one rejected parameter combination."""
    weight = torch.tensor([[10, 20, 30, 40]])

    # Non-divisible group count: this should be handled normally without
    # throwing errors, because the function falls back to the global load
    # balancing strategy.
    phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 4)

    # A combination that does raise: the number of physical experts is not
    # divisible by the number of GPUs (7 not divisible by 4).
    with pytest.raises(AssertionError):
        rebalance_experts(weight, 7, 2, 2, 4)
def test_small_scale_hierarchical():
    """Small hierarchical case: 8 experts, 4 groups, 2 nodes, 4 GPUs."""
    weight = torch.tensor([
        [100, 50, 200, 75, 150, 25, 300, 80],  # 8 experts
    ])
    num_replicas = 12
    num_groups = 4  # 4 groups, 2 experts each
    num_nodes = 2  # 2 nodes
    num_gpus = 4  # 4 GPUs

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Basic constraints
    assert phy2log.shape == (1, 12)
    assert logcnt.shape == (1, 8)
    assert torch.sum(logcnt) == num_replicas
    assert torch.all(logcnt >= 1)

    # The most heavily loaded expert must be replicated.
    max_weight_expert = torch.argmax(weight[0])
    hottest_replicas = logcnt[0, max_weight_expert]
    assert hottest_replicas >= 2, (
        "Highest weight expert should have multiple replicas")
def test_global_load_balance_fallback():
    """Group count not divisible by node count still yields a valid layout."""
    # When num_groups % num_nodes != 0, the call should fall back to global
    # load balancing.
    weight = torch.tensor([[10, 20, 30, 40, 50, 60]])
    num_replicas = 8
    num_nodes = 2
    num_groups = 3  # Cannot be divided evenly by num_nodes=2
    num_gpus = 4

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Should work normally, just using the global load balancing strategy.
    assert phy2log.shape == (1, 8)
    assert logcnt.shape == (1, 6)
    replica_total = torch.sum(logcnt)
    assert replica_total == num_replicas
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_device_compatibility(device):
    """Run the rebalancer with the input tensor placed on `device`."""
    wants_cuda = device == "cuda"
    if wants_cuda and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weight = torch.tensor([[10, 20, 30, 40]], device=device)
    num_replicas, num_groups, num_nodes, num_gpus = 6, 2, 1, 2

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)

    # Function will convert to CPU internally, but should handle different
    # device inputs normally.
    assert phy2log.shape == (1, 6)
    assert logcnt.shape == (1, 4)
def test_additional_cases():
    """Cover a larger topology and two mirrored load distributions."""
    # Case 1: larger distributed setup (16 experts, 24 slots).
    weight1 = torch.tensor([[
        50, 100, 75, 120, 90, 60, 80, 110,
        40, 70, 95, 85, 65, 55, 45, 35,
    ]])
    phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)

    assert phy2log1.shape == (1, 24)
    assert logcnt1.shape == (1, 16)
    assert torch.sum(logcnt1) == 24

    # Case 2: two layers with opposite weight distributions.
    decreasing = [200, 150, 100, 50, 25, 12]  # Decreasing weights
    increasing = [12, 25, 50, 100, 150, 200]  # Increasing weights
    weight2 = torch.tensor([decreasing, increasing])
    phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)

    assert phy2log2.shape == (2, 10)
    assert logcnt2.shape == (2, 6)

    # In each layer the heaviest expert must have at least two replicas.
    for layer in range(2):
        heaviest = torch.argmax(weight2[layer])
        assert logcnt2[layer, heaviest] >= 2
if __name__ == "__main__":
    # Manual run: print the physical-to-logical layout for the reference
    # example, then execute the basic test directly.
    weight = torch.tensor([
        [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
        [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27],
    ])

    num_replicas, num_groups, num_nodes, num_gpus = 16, 4, 2, 8

    phy2log, log2phy, logcnt = rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)
    print(phy2log)

    test_basic_rebalance()
tests/distributed/test_eplb_execute.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
multiprocessing
import
os
import
random
import
pytest
import
torch
import
torch.distributed
from
vllm.distributed.eplb.rebalance_execute
import
(
rearrange_expert_weights_inplace
)
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
get_tp_group
,
init_distributed_environment
)
from
vllm.utils
import
update_environment_variables
def distributed_run(fn, world_size):
    """
    Run `fn` in `world_size` separate processes and require all to succeed.

    Each process receives a single dict argument holding the rank and
    rendezvous settings as strings. All processes are started first, then
    joined, and finally every exit code is asserted to be 0.
    """
    world = str(world_size)
    workers: list[multiprocessing.Process] = []

    for rank in range(world_size):
        # A fresh dict per process; local rank equals global rank here.
        env: dict[str, str] = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': world,
            'LOCAL_WORLD_SIZE': world,
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        worker = multiprocessing.Process(target=fn, args=(env, ))
        workers.append(worker)
        worker.start()

    # Wait for everyone before checking any result.
    for worker in workers:
        worker.join()

    for worker in workers:
        assert worker.exitcode == 0
def worker_fn_wrapper(fn):
    """
    Wrap a zero-argument worker so it can be launched with an env dict.

    `multiprocessing.Process` cannot accept environment variables directly,
    so they are passed as an argument and applied inside the child before
    the CUDA device and distributed environment are set up.
    """

    def wrapped_fn(env):
        update_environment_variables(env)

        # Bind this process to the GPU named by its LOCAL_RANK.
        gpu = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        torch.cuda.set_device(gpu)
        init_distributed_environment()

        # Ensure each worker process has the same random seed
        seed = 42
        random.seed(seed)
        torch.manual_seed(seed)

        fn()

    return wrapped_fn
def create_expert_indices_with_redundancy(
    num_layers: int,
    num_logical_experts: int,
    total_physical_experts: int,
    redundancy_config: list[int],  # redundancy for each logical expert
) -> torch.Tensor:
    """
    Build a randomly shuffled physical-to-logical expert index table.

    Args:
        num_layers: number of layers
        num_logical_experts: number of logical experts
        total_physical_experts: total number of physical experts
        redundancy_config: number of physical copies of each logical expert;
            must have one entry per logical expert and sum to
            `total_physical_experts`

    Returns:
        indices: Shape (num_layers, total_physical_experts), dtype long.
            Every row holds the same multiset of logical IDs, permuted
            independently per layer with `torch.randperm`.
    """
    assert sum(redundancy_config) == total_physical_experts
    assert len(redundancy_config) == num_logical_experts

    # Unshuffled layout shared by all layers: expert 0 repeated
    # `redundancy_config[0]` times, then expert 1, and so on.
    ordered_slots = torch.tensor(
        [
            expert_id
            for expert_id, copies in enumerate(redundancy_config)
            for _ in range(copies)
        ],
        dtype=torch.long,
    )

    indices = torch.zeros(num_layers,
                          total_physical_experts,
                          dtype=torch.long)

    # One permutation is drawn per layer, in layer order.
    for layer in range(num_layers):
        shuffle = torch.randperm(total_physical_experts)
        indices[layer] = ordered_slots[shuffle]

    return indices
def create_expert_weights(
    num_layers: int,
    num_local_experts: int,
    hidden_sizes: list[int],
    rank: int,
    device: torch.device,
    physical_to_logical_mapping: torch.Tensor,
) -> list[list[torch.Tensor]]:
    """
    Build deterministic fake expert weights for one rank.

    Each local expert row is an `arange` whose start value depends only on
    the logical expert ID, the layer and the weight-matrix index, so all
    replicas of the same logical expert carry identical values.

    Args:
        num_layers: number of layers to generate
        num_local_experts: number of physical experts held by this rank
        hidden_sizes: row length of each weight matrix
        rank: rank whose slice of the mapping is used
        device: device the tensors are created on
        physical_to_logical_mapping: mapping[layer, physical_pos] =
            logical_expert_id, indexed here with the global position
            `rank * num_local_experts + local_expert`

    Returns:
        One list per layer, each with one float32 tensor of shape
        (num_local_experts, hidden_size) per entry of `hidden_sizes`.
    """
    first_global_pos = rank * num_local_experts
    expert_weights = []

    for layer in range(num_layers):
        layer_weights = []

        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = torch.zeros(num_local_experts,
                                        hidden_size,
                                        device=device,
                                        dtype=torch.float32)
            # Part of the start value shared by every expert in this matrix.
            matrix_offset = layer * 100 + weight_idx * 10

            for local_expert in range(num_local_experts):
                # Logical expert served by this physical slot.
                logical_expert_id = physical_to_logical_mapping[
                    layer, first_global_pos + local_expert].item()

                # Start value is keyed on the logical ID so that replicas
                # of one logical expert are indistinguishable.
                start = logical_expert_id * 1000 + matrix_offset
                weight_tensor[local_expert] = torch.arange(
                    start,
                    start + hidden_size,
                    device=device,
                    dtype=torch.float32)

            layer_weights.append(weight_tensor)

        expert_weights.append(layer_weights)

    return expert_weights
def create_redundancy_config(
    num_logical_experts: int,
    num_physical_experts: int,
) -> list[int]:
    """
    Return a random replica count for every logical expert.

    Every logical expert starts with one replica; each physical expert
    beyond that is handed to a logical expert picked with `random.choice`.
    """
    extra_replicas = num_physical_experts - num_logical_experts
    redundancy_config = [1] * num_logical_experts
    candidates = range(num_logical_experts)

    # One random draw per surplus physical expert.
    for _ in range(extra_replicas):
        chosen = random.choice(candidates)
        redundancy_config[chosen] += 1

    return redundancy_config
def verify_expert_weights_after_shuffle(
    expert_weights: list[list[torch.Tensor]],
    new_indices: torch.Tensor,
    hidden_sizes: list[int],
    ep_rank: int,
    num_local_experts: int,
):
    """
    Check that this rank's weights match the layout given by `new_indices`.

    For every layer, weight matrix and local expert, the row is compared
    against an `arange` starting at
    `logical_id * 1000 + layer * 100 + weight_idx * 10`, where `logical_id`
    is read from `new_indices` at the expert's global position. A mismatch
    makes `torch.testing.assert_close` raise.
    """
    first_global_pos = ep_rank * num_local_experts

    for layer, layer_weights in enumerate(expert_weights):
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = layer_weights[weight_idx]
            matrix_offset = layer * 100 + weight_idx * 10

            for local_expert in range(num_local_experts):
                # Logical expert this physical slot should now hold.
                expected_logical_expert = new_indices[
                    layer, first_global_pos + local_expert].item()

                actual_weights = weight_tensor[local_expert]

                # Rebuild the reference row on the same device and dtype
                # as the tensor under test.
                expected_base = expected_logical_expert * 1000 + matrix_offset
                expected_weights = torch.arange(expected_base,
                                                expected_base + hidden_size,
                                                device=actual_weights.device,
                                                dtype=actual_weights.dtype)

                torch.testing.assert_close(
                    actual_weights,
                    expected_weights,
                    msg=f"Layer {layer}, weight {weight_idx},"
                    f"local expert {local_expert}: "
                    f"weights do not match. "
                    f"Expected logical expert {expected_logical_expert}")
def verify_redundant_experts_have_same_weights(
    expert_weights: list[list[torch.Tensor]],
    indices: torch.Tensor,
    hidden_sizes: list[int],
    world_size: int,
    num_local_experts: int,
):
    """
    Check that all replicas of a logical expert hold identical weights.

    For each layer, every weight matrix is collected from all ranks with
    `torch.distributed.all_gather`; each physical expert is then compared
    against the first physical expert that maps to the same logical ID in
    `indices`.
    """
    total_physical_experts = world_size * num_local_experts
    num_weight_matrices = len(hidden_sizes)

    for layer, layer_weights in enumerate(expert_weights):
        # One gathered tensor per weight matrix, covering all ranks.
        all_weights: list[torch.Tensor] = []

        for weight_idx, hidden_size in enumerate(hidden_sizes):
            # This rank's weights: [num_local_experts, hidden_size]
            local_weights = layer_weights[weight_idx]

            # Destination buffer: [total_physical_experts, hidden_size]
            gathered_weights = torch.zeros(total_physical_experts,
                                           hidden_size,
                                           device=local_weights.device,
                                           dtype=local_weights.dtype)

            # Split along dim 0 into one chunk per rank; the chunks are
            # passed to all_gather as its output list.
            per_rank_chunks = list(
                torch.chunk(gathered_weights, world_size, dim=0))
            torch.distributed.all_gather(per_rank_chunks, local_weights)

            all_weights.append(gathered_weights)

        # Physical position of the first replica seen per logical expert.
        first_replica_pos: dict[int, int] = {}

        for physical_pos in range(total_physical_experts):
            logical_expert_id = int(indices[layer, physical_pos].item())

            if logical_expert_id not in first_replica_pos:
                # First replica becomes the reference for later ones.
                first_replica_pos[logical_expert_id] = physical_pos
                continue

            reference_pos = first_replica_pos[logical_expert_id]
            for weight_idx in range(num_weight_matrices):
                torch.testing.assert_close(
                    all_weights[weight_idx][physical_pos],
                    all_weights[weight_idx][reference_pos],
                    msg=f"Layer {layer}, weight {weight_idx},"
                    f"logical expert {logical_expert_id}: "
                    f"Physical expert {physical_pos} has different weights"
                    f"than expected")
@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        # 2 GPU, 2 experts per GPU
        # 3 logical experts, 4 physical experts, 1 redundant experts
        (2, 1, 2, 3),
        # 2 GPU, 3 experts per GPU
        # 4 logical experts, 6 physical experts, 2 redundant experts
        (2, 2, 3, 4),
        # 2 GPU, 8 experts per GPU
        # 16 logical experts, 16 physical experts, 0 redundant experts
        (2, 4, 8, 16),
        # 4 GPU, 2 experts per GPU
        # 6 logical experts, 8 physical experts, 2 redundant experts
        (4, 1, 2, 6),
        # 4 GPU, 2 experts per GPU
        # 5 logical experts, 8 physical experts, 3 redundant experts
        (4, 2, 2, 5),
        # 4 GPU, 8 experts per GPU
        # 16 logical experts, 32 physical experts, 16 redundant experts
        (4, 8, 8, 16),
    ])
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
                                                  num_local_experts,
                                                  num_logical_experts):
    """Rearrange weights between two random layouts and verify the result."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        # Initialize model parallel (using tensor parallel as an entrypoint
        # to expert parallel)
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        # Test parameters
        total_physical_experts = world_size * num_local_experts
        hidden_sizes = [32, 64]  # Two different weight matrices

        def random_layout():
            # Draw a redundancy configuration, then a shuffled index table.
            config = create_redundancy_config(num_logical_experts,
                                              total_physical_experts)
            return create_expert_indices_with_redundancy(
                num_layers,
                num_logical_experts,
                total_physical_experts,
                config,
            )

        # Old layout first, then the new one, so the random draws happen
        # in a fixed order on every rank.
        old_indices = random_layout()
        new_indices = random_layout()

        # Weights start out laid out according to the old indices.
        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        # Execute weight rearrangement
        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=False,
        )

        # Verify the rearrangement result
        verify_expert_weights_after_shuffle(
            expert_weights,
            new_indices,
            hidden_sizes,
            ep_rank,
            num_local_experts,
        )
        verify_redundant_experts_have_same_weights(
            expert_weights,
            new_indices,
            hidden_sizes,
            world_size,
            num_local_experts,
        )

    distributed_run(worker_fn, world_size)
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
    """
    Rearranging with identical old and new indices must leave every weight
    tensor unchanged.
    """
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 2
        num_local_experts = 2
        hidden_sizes = [32, 64]
        total_physical_experts = world_size * num_local_experts
        # Half as many logical experts as slots: some redundancy.
        num_logical_experts = total_physical_experts // 2

        # Every logical expert gets exactly two replicas.
        redundancy_config = [2] * num_logical_experts

        # A single index table is used as both the old and the new layout.
        indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            redundancy_config)

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               indices)

        # Snapshot the weights before the call.
        original_weights = [[tensor.clone() for tensor in layer_weights]
                            for layer_weights in expert_weights]

        # Execute rearrangement (should be no change)
        rearrange_expert_weights_inplace(
            indices,
            indices,  # Same indices
            expert_weights,
            ep_group,
            is_profile=False)

        # Compare every tensor against its snapshot.
        for layer in range(num_layers):
            for weight_idx in range(len(hidden_sizes)):
                current = expert_weights[layer][weight_idx]
                snapshot = original_weights[layer][weight_idx]
                torch.testing.assert_close(
                    current,
                    snapshot,
                    msg=f"Layer {layer}, weight {weight_idx} should remain "
                    f"unchanged")

    distributed_run(worker_fn, world_size)
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_profile_mode(world_size):
    """Profile mode with differing layouts must leave the weights untouched."""
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 1
        num_local_experts = 2
        hidden_sizes = [32]
        total_physical_experts = world_size * num_local_experts
        num_logical_experts = total_physical_experts // 2

        # Two independently drawn redundancy configurations ...
        old_redundancy = create_redundancy_config(num_logical_experts,
                                                  total_physical_experts)
        new_redundancy = create_redundancy_config(num_logical_experts,
                                                  total_physical_experts)

        # ... and the shuffled index tables built from them.
        old_indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            old_redundancy)
        new_indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            new_redundancy)

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        # Snapshot the weights before the call.
        original_weights = [[tensor.clone() for tensor in layer_weights]
                            for layer_weights in expert_weights]

        # Execute profile mode rearrangement
        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=True  # Profile mode
        )

        # In profile mode, the weights should remain unchanged
        for layer in range(num_layers):
            for weight_idx in range(len(hidden_sizes)):
                current = expert_weights[layer][weight_idx]
                snapshot = original_weights[layer][weight_idx]
                torch.testing.assert_close(
                    current,
                    snapshot,
                    msg="In profile mode, the weights should remain unchanged")

    distributed_run(worker_fn, world_size)
tests/models/test_initialization.py
View file @
e9fd658a
...
@@ -31,12 +31,20 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
...
@@ -31,12 +31,20 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
text_config
=
hf_config
.
get_text_config
()
text_config
=
hf_config
.
get_text_config
()
# Ensure at least 2 experts per group
# Since `grouped_topk` assumes top-2
num_experts
=
getattr
(
text_config
,
'n_group'
,
1
)
*
2
text_config
.
update
({
text_config
.
update
({
"num_layers"
:
1
,
"num_layers"
:
1
,
"num_hidden_layers"
:
1
,
"num_hidden_layers"
:
1
,
"num_experts"
:
2
,
"num_experts"
:
num_experts
,
"num_experts_per_tok"
:
2
,
"num_experts_per_tok"
:
2
,
"num_local_experts"
:
2
,
"num_local_experts"
:
num_experts
,
# Otherwise there will not be any expert layers
"first_k_dense_replace"
:
0
,
# To avoid OOM on DeepSeek-V3
"n_routed_experts"
:
num_experts
,
})
})
if
hasattr
(
hf_config
,
"vision_config"
):
if
hasattr
(
hf_config
,
"vision_config"
):
...
...
vllm/config.py
View file @
e9fd658a
...
@@ -1775,6 +1775,25 @@ class ParallelConfig:
...
@@ -1775,6 +1775,25 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel
:
bool
=
False
enable_expert_parallel
:
bool
=
False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb
:
bool
=
False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts
:
int
=
0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size
:
int
=
1000
"""Window size for expert load recording."""
eplb_step_interval
:
int
=
3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness
:
bool
=
False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
max_parallel_loading_workers
:
Optional
[
int
]
=
None
max_parallel_loading_workers
:
Optional
[
int
]
=
None
"""Maximum number of parallel loading workers when loading model
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
sequentially in multiple batches. To avoid RAM OOM when using tensor
...
@@ -1913,6 +1932,20 @@ class ParallelConfig:
...
@@ -1913,6 +1932,20 @@ class ParallelConfig:
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
logger
.
info
(
"Disabling V1 multiprocessing for external launcher."
)
if
self
.
enable_eplb
:
if
not
current_platform
.
is_cuda
():
raise
ValueError
(
"Expert parallelism load balancing is only supported on "
"CUDA devices now."
)
if
self
.
num_redundant_experts
<
0
:
raise
ValueError
(
"num_redundant_experts must be non-negative, but got "
f
"
{
self
.
num_redundant_experts
}
."
)
else
:
if
self
.
num_redundant_experts
!=
0
:
raise
ValueError
(
"num_redundant_experts should be used with EPLB."
f
"
{
self
.
num_redundant_experts
}
."
)
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
# We use multiprocessing by default if world_size fits on the
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
# current node and we aren't in a ray placement group.
...
...
vllm/distributed/eplb/__init__.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
'''
Expert parallelism load balancer (EPLB).
'''
from
.eplb_state
import
*
from
.rebalance_algo
import
*
vllm/distributed/eplb/eplb_state.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) metrics and states.
# Glossary
- **Logical Expert**: An expert that is part of the model's logical structure.
It holds a set of weights and is replicated across multiple physical
experts.
- **Redundant Expert**: To achieve load balancing, for some popular logical
experts, we create additional copies of the expert weights. During inference,
each of these copies can be routed to by the same set of tokens.
- **Physical Expert**: An expert that is instantiated on a specific device.
It is a replica of a logical expert and can be rearranged across devices.
I.e., one logical expert may have multiple sets of weights initialized on
different devices, and each of these sets is a physical expert.
- **Local Physical Expert**: A physical expert that is instantiated on the
current device.
For example: DeepSeek-R1 has 256 logical experts, so each MoE layer
has 256 sets of linear layer weights in the model parameters. If we add 32
redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in
total. And when deploying, we'll have 288 sets of linear layer weights for each
MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local
physical experts.
"""
import
time
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
import
torch
from
torch.distributed
import
all_gather
,
all_reduce
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.interfaces
import
MixtureOfExperts
from
.rebalance_algo
import
rebalance_experts
from
.rebalance_execute
import
rearrange_expert_weights_inplace
logger
=
init_logger
(
__name__
)
@dataclass
class EplbState:
    """EPLB metrics and expert-placement state shared by all MoE layers."""

    physical_to_logical_map: torch.Tensor
    """
    Mapping from physical experts to logical experts.

    Shape: (num_moe_layers, num_physical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[0, 1, 2, 3, 0, 1],
     [0, 2, 0, 1, 0, 3]]
    ```
    """
    logical_to_physical_map: torch.Tensor
    """
    Mapping from logical experts to physical experts.

    This is a sparse matrix, where -1 indicates no mapping.

    Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[[0, 4, -1],
      [1, 5, -1],
      [2, -1, -1],
      [3, -1, -1]],
     [[0, 2, 4],
      [3, -1, -1],
      [1, -1, -1],
      [5, -1, -1]]]
    ```
    """
    logical_replica_count: torch.Tensor
    """
    Number of replicas for each logical expert.
    This is exactly the non-`-1` count in the `logical_to_physical_map`.

    Shape: (num_moe_layers, num_logical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the count could look like this:

    ```
    [[2, 2, 1, 1],
     [3, 1, 1, 1]]
    ```
    """

    expert_load_pass: torch.Tensor
    """
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

    Shape: (num_moe_layers, num_local_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

    Shape: (window_size, num_moe_layers, num_local_physical_experts)
    """
    expert_load_window_step: int = 0
    """
    Current step in the sliding window.

    Different from `expert_rearrangement_step`, each EP rank may have its own
    `expert_load_window_step`.
    """
    expert_load_window_size: int = 0
    """
    Size of the expert load sliding window.
    This is a constant and is taken from the config.
    """

    expert_rearrangement_step: int = 0
    """
    Steps after last rearrangement.
    Will trigger a rearrangement if it exceeds the threshold.

    NOTE: Keep in mind that all EP ranks need to have the same
    `expert_rearrangement_step` value to ensure synchronization.
    Otherwise, the rearrangement will hang at collective
    communication calls.
    """
    expert_rearrangement_step_interval: int = 0
    """
    Interval for expert rearrangement steps.
    This is a constant and is taken from the config.
    """

    @staticmethod
    def build_initial_global_physical_to_logical_map(
        num_routed_experts: int,
        num_redundant_experts: int,
    ) -> Sequence[int]:
        """
        Build an initial expert arrangement using the following structure:
        [original routed experts, redundant experts]

        The redundant slots are assigned round-robin over the routed experts
        (slot `i` replicates logical expert `i % num_routed_experts`).

        Returns:
            physical_to_logical_map (Sequence[int]): A list of integers,
                where each integer is the index of the logical expert
                that the corresponding physical expert maps to.
        """
        global_physical_to_logical_map = list(range(num_routed_experts))
        global_physical_to_logical_map += [
            i % num_routed_experts for i in range(num_redundant_experts)
        ]
        return global_physical_to_logical_map

    @classmethod
    def build(
        cls,
        model: MixtureOfExperts,
        device: torch.device,
        parallel_config: ParallelConfig,
    ) -> "EplbState":
        """
        Build the initial EPLB state.

        Args:
            model: The MoE model. Its expert/layer counts size every tensor,
                and `model.set_eplb_state` is called with the load buffer and
                the logical-to-physical views before returning.
            device: Device on which all state tensors are allocated.
            parallel_config: Provides `eplb_window_size` and
                `eplb_step_interval`.

        Returns:
            A new `EplbState` in which every MoE layer shares the same initial
            mapping and the rearrangement counter starts at 3/4 of the
            interval.
        """
        physical_to_logical_map_list = (
            cls.build_initial_global_physical_to_logical_map(
                model.num_routed_experts,
                model.num_redundant_experts,
            ))
        physical_to_logical_map = torch.tensor(
            physical_to_logical_map_list,
            device=device,
        )

        # Build the inverse mapping on the host with plain lists and upload it
        # once; indexing device tensors element by element would issue one
        # tiny device operation per physical expert.
        num_logical_experts = model.num_logical_experts
        max_replicas = model.num_redundant_experts + 1
        logical_to_physical_rows = [[-1] * max_replicas
                                    for _ in range(num_logical_experts)]
        replica_counts = [0] * num_logical_experts
        for physical_idx in range(model.num_physical_experts):
            logical_idx = physical_to_logical_map_list[physical_idx]
            slot = replica_counts[logical_idx]
            logical_to_physical_rows[logical_idx][slot] = physical_idx
            replica_counts[logical_idx] += 1

        logical_to_physical_map = torch.tensor(
            logical_to_physical_rows,
            dtype=torch.long,
            device=device,
        ).reshape(num_logical_experts, max_replicas)
        logical_replica_count = torch.tensor(
            replica_counts,
            dtype=torch.long,
            device=device,
        )

        # Duplicate initial mapping for all layers
        physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
        ).contiguous()
        logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
            -1,
        ).contiguous()
        logical_replica_count = logical_replica_count.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
        ).contiguous()

        expert_load_pass = torch.zeros(
            (model.num_moe_layers, model.num_local_physical_experts),
            dtype=torch.int32,
            device=device,
        )
        expert_load_window_size = parallel_config.eplb_window_size
        expert_load_window = torch.zeros(
            (expert_load_window_size, model.num_moe_layers,
             model.num_local_physical_experts),
            dtype=torch.int32,
            device=device,
        )

        # Set the initial progress of rearrangement to 3/4
        eplb_step_interval = parallel_config.eplb_step_interval
        expert_rearrangement_step = max(
            0, eplb_step_interval - eplb_step_interval // 4)

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,
            logical_replica_count,
        )

        return cls(
            physical_to_logical_map,
            logical_to_physical_map,
            logical_replica_count,
            expert_load_pass,
            expert_load_window,
            expert_load_window_size=expert_load_window_size,
            expert_rearrangement_step=expert_rearrangement_step,
            expert_rearrangement_step_interval=eplb_step_interval,
        )

    def step(self,
             model: MixtureOfExperts,
             is_dummy: bool = False,
             is_profile: bool = False,
             log_stats: bool = False) -> None:
        """
        Step the EPLB state.

        Args:
            model (MixtureOfExperts): The MoE model.
            is_dummy (bool): If `True`, this is a dummy step and the load
                metrics recorded in this forward pass will not count. Defaults
                to `False`.
            is_profile (bool): If `True`, perform a dummy rearrangement
                with maximum communication cost. This is used in `profile_run`
                to reserve enough memory for the communication buffer.
            log_stats (bool): If `True`, log the expert load metrics.

        # Stats
            The metrics are all summed up across layers.
            - `avg_tokens`: The average load across ranks.
            - `max_tokens`: The maximum load across ranks.
            - `balancedness`: The ratio of average load to maximum load.
        """

        if is_profile:
            self.rearrange(model, is_profile=True)
            return

        if is_dummy:
            # Do not record load metrics for dummy steps
            self.expert_load_pass.zero_()

        if log_stats:
            # `num_tokens`: (num_moe_layers,)
            num_tokens = self.expert_load_pass.sum(dim=-1)

            # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            num_tokens_list = [
                torch.empty_like(num_tokens) for _ in range(ep_group.size())
            ]
            all_gather(num_tokens_list, num_tokens, group=ep_group)
            # Stack to get (num_ranks, num_moe_layers)
            num_tokens_per_rank = torch.stack(num_tokens_list).float()

            # Compute balancedness ratio:
            # for each layer:
            #   (mean load across ranks) / (max load across ranks)
            avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
            max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
                dim=0)

            # Just to make type checker happy
            tokens_tensors: list[float] = torch.stack(
                [avg_tokens_tensor, max_tokens_tensor]).tolist()
            avg_tokens, max_tokens = tokens_tensors
            balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

            if ep_group.rank() == 0:
                logger.info(
                    "EPLB step: avg_tokens=%.2f, max_tokens=%d, "
                    "balancedness=%.4f", avg_tokens, max_tokens, balancedness)

        # Update the expert load sliding window
        if not is_dummy:
            self.expert_load_window[self.expert_load_window_step] = (
                self.expert_load_pass.clone())
            self.expert_load_window_step += 1
            # The window is a ring buffer: wrap around once it is full.
            if self.expert_load_window_step >= self.expert_load_window_size:
                self.expert_load_window_step = 0
            self.expert_load_pass.zero_()

        # Step the expert rearrangement step
        # Note that even if this is a dummy step, we still increment the
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1
        if (self.expert_rearrangement_step
                >= self.expert_rearrangement_step_interval):
            self.expert_rearrangement_step = 0
            self.rearrange(model)

    def rearrange(self,
                  model: MixtureOfExperts,
                  is_profile: bool = False) -> None:
        """
        Rearrange the experts according to the current load.

        The local load window is mapped to logical experts, summed over the
        window and all-reduced over the EP group; `rebalance_experts` then
        computes new mappings and `rearrange_expert_weights_inplace` moves the
        weights.

        Args:
            model: The MoE model whose `expert_weights` are rearranged.
            is_profile: Forwarded to `rearrange_expert_weights_inplace`; when
                `True` the stored mappings are left untouched.
        """
        ep_group = get_ep_group().device_group
        ep_rank = ep_group.rank()

        time_start = None
        is_main_rank = ep_rank == 0
        if is_main_rank:
            torch.cuda.synchronize()
            time_start = time.perf_counter()
            logger.info("Rearranging experts %s...",
                        "(profile)" if is_profile else "")

        # This mapping is only used here, so we do not store it in the state
        physical_expert_start = ep_rank * model.num_local_physical_experts
        physical_expert_end = (physical_expert_start +
                               model.num_local_physical_experts)
        # (num_moe_layers, num_local_physical_experts)
        local_physical_to_logical_map = self.physical_to_logical_map[
            :,
            physical_expert_start:physical_expert_end,
        ]

        # Map the local physical expert load to global logical experts
        logical_expert_load_window = torch.zeros(
            self.expert_load_window_size,
            model.num_moe_layers,
            model.num_logical_experts,
            dtype=self.expert_load_window.dtype,
            device=self.expert_load_window.device,
        )
        logical_expert_load_window.scatter_add_(
            dim=-1,
            index=local_physical_to_logical_map.unsqueeze(0).expand_as(
                self.expert_load_window).long(),
            src=self.expert_load_window,
        )

        # Perform all-reduce to get the expert load across all ranks
        global_expert_load_window = logical_expert_load_window.sum(dim=0)
        all_reduce(global_expert_load_window, group=ep_group)

        # TODO(bowen): Treat differently for prefill and decode nodes
        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
        num_nodes = get_node_count()
        num_gpus = ep_group.size()

        if num_gpus % num_nodes != 0:
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
                f"{num_gpus=}, {num_nodes=}")
            # Actually fall back: treating the deployment as a single node
            # keeps the node-level divisibility requirements of the
            # hierarchical algorithm trivially satisfied instead of tripping
            # its `num_gpus % num_nodes == 0` assertion.
            num_nodes = 1

        # Get new expert mappings
        (
            new_physical_to_logical_map,
            new_logical_to_physical_map,
            new_logical_replica_count,
        ) = (rebalance_experts(
            global_expert_load_window,
            num_replicas,
            num_groups,
            num_nodes,
            num_gpus,
        ))

        # Update expert weights
        rearrange_expert_weights_inplace(
            self.physical_to_logical_map,
            new_physical_to_logical_map,
            model.expert_weights,
            ep_group,
            is_profile,
        )

        if not is_profile:
            # In-place copies keep the tensor objects handed to the model in
            # `build` (via `set_eplb_state`) pointing at the new mappings.
            self.physical_to_logical_map.copy_(new_physical_to_logical_map)
            self.logical_to_physical_map.copy_(new_logical_to_physical_map)
            self.logical_replica_count.copy_(new_logical_replica_count)

        if is_main_rank:
            assert time_start is not None
            torch.cuda.synchronize()
            time_end = time.perf_counter()
            logger.info(
                "Rearranged experts%sin %.2f seconds.",
                " (profile) " if is_profile else " ",
                time_end - time_start,
            )
vllm/distributed/eplb/rebalance_algo.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
See [#12](https://github.com/deepseek-ai/EPLB/issues/12) for an example of
how the EPLB algorithm works.
"""
import
torch
def balanced_packing(weight: torch.Tensor,
                     num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly
    n/m objects and the weights of all packs are as balanced as possible.

    Items are visited from heaviest to lightest and each one goes to the
    lightest pack that still has room (ties go to the lowest pack index).

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack

    Note:
        When every pack holds exactly one item the result is returned on
        `weight.device`; otherwise both tensors are created on the CPU.
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        # Trivial case: item k goes to pack k at rank 0.
        pack_index = torch.arange(weight.size(-1),
                                  dtype=torch.int64,
                                  device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    # Move everything the greedy loop needs to plain Python values once, so
    # the inner loop does not allocate and compare 0-d tensors per item.
    sorted_groups = weight.float().sort(-1, descending=True).indices.tolist()
    weight_rows = weight.tolist()

    pack_index = torch.full_like(weight,
                                 fill_value=-1,
                                 dtype=torch.int64,
                                 device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for layer in range(num_layers):
        layer_weights = weight_rows[layer]
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        layer_pack = [-1] * num_groups
        layer_rank = [-1] * num_groups
        for group in sorted_groups[layer]:
            # Lightest pack that is not yet full.
            pack = min(
                (p for p in range(num_packs)
                 if pack_items[p] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            layer_pack[group] = pack
            layer_rank[group] = pack_items[pack]
            pack_weights[pack] += layer_weights[group]
            pack_items[pack] += 1
        pack_index[layer] = torch.tensor(layer_pack, dtype=torch.int64)
        rank_in_pack[layer] = torch.tensor(layer_rank, dtype=torch.int64)
    return pack_index, rank_in_pack
def replicate_experts(
        weight: torch.Tensor,
        num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Greedily add replicas so that the heaviest per-replica load stays as low
    as possible: `num_log` logical experts are expanded to `num_phy` physical
    slots, one extra slot at a time.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    num_rows, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0

    dev = weight.device
    row_ids = torch.arange(num_rows, dtype=torch.int64, device=dev)
    # The first `num_log` slots are the identity mapping with replica rank 0.
    phy2log = torch.arange(num_phy, dtype=torch.int64,
                           device=dev).repeat(num_rows, 1)
    rank = torch.zeros(num_rows, num_phy, dtype=torch.int64, device=dev)
    logcnt = torch.ones(num_rows, num_log, dtype=torch.int64, device=dev)

    for slot in range(num_log, num_phy):
        # The expert with the largest load per existing replica gets the
        # next replica.
        load_per_replica = weight / logcnt
        hottest = load_per_replica.max(dim=-1).indices
        phy2log[:, slot] = hottest
        rank[:, slot] = logcnt[row_ids, hottest]
        logcnt[row_ids, hottest] += 1

    return phy2log, rank, logcnt
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Hierarchical rebalancing: expert groups are first packed onto nodes,
    experts are then replicated within each node, and finally the physical
    experts of each node are packed onto that node's GPUs.

    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
        (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        physical_replica_rank: [num_moe_layers, num_physical_experts], the
            replica rank of each physical expert within its logical expert
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        # Row-wise inverse of a permutation: inv[r, perm[r, j]] = j.
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64,
                         device=perm.device).expand(perm.shape),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(
        tokens_per_group, num_nodes)
    # "mlog" is the logical expert index after reordering groups node by
    # node: (node * groups_per_node + slot-in-node) * group_size + offset of
    # the expert inside its group.
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
                 group_size).unsqueeze(-1) +
                torch.arange(group_size,
                             dtype=torch.int64,
                             device=group_pack_index.device)).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    # NOTE(review): `gather` needs `weight` and `mlog2log` on the same device;
    # `balanced_packing` may return CPU tensors, so this presumably relies on
    # the caller passing a CPU `weight` -- confirm.
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes)
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes)

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
                                                num_gpus // num_nodes)
    # "pphy" is the physical expert index after packing onto GPUs.
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    # [num_layers * num_nodes, num_physical_experts // num_nodes]; values are
    # node-local mlog indices
    pphy2mlog = phy2mlog.gather(-1, pphy2phy)
    # Add each node's starting offset to turn node-local mlog indices into
    # per-layer mlog indices, then merge the node dimension back into the row.
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
        0,
        num_logical_experts,
        num_logical_experts // num_nodes,
        device=group_pack_index.device,
    ).view(1, -1, 1)).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    The load statistics are converted to float and moved to the CPU before
    the placement is computed.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all
            logical experts
        num_replicas: number of physical experts, must be a multiple of
            `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of
            each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica
            indices for each expert (`-1` marks an unused slot)
        expert_count: [layers, num_logical_experts], number of physical
            replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()

    # Hierarchical policy when the groups split evenly over the nodes;
    # otherwise the global policy, i.e. a single group on a single node.
    if num_groups % num_nodes == 0:
        policy_groups, policy_nodes = num_groups, num_nodes
    else:
        policy_groups, policy_nodes = 1, 1
    phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
        weight, num_replicas, policy_groups, policy_nodes, num_gpus)

    # A logical expert can own at most every redundant slot plus its own.
    maxlogcnt = num_replicas - num_logical_experts + 1
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Flattened (logical expert, replica rank) slot of every physical expert.
    flat_slots = phy2log * maxlogcnt + phyrank
    physical_ids = torch.arange(num_replicas,
                                dtype=torch.int64,
                                device=log2phy.device).expand(num_layers, -1)
    log2phy.view(num_layers, -1).scatter_(-1, flat_slots, physical_ids)
    return phy2log, log2phy, logcnt
__all__
=
[
"rebalance_experts"
]
vllm/distributed/eplb/rebalance_execute.py
0 → 100644
View file @
e9fd658a
# SPDX-License-Identifier: Apache-2.0
"""
The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from
collections.abc
import
Iterable
,
MutableSequence
,
Sequence
from
functools
import
partial
import
torch
from
torch.distributed
import
(
P2POp
,
ProcessGroup
,
all_gather
,
batch_isend_irecv
,
get_global_rank
)
def idx_local_to_global(
    local_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Translate an expert index local to `ep_rank` into its global index,
    given that every rank holds `local_cnt` experts.
    """
    rank_offset = ep_rank * local_cnt
    return rank_offset + local_idx
def idx_global_to_local(
    global_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Translate a global expert index into the index local to `ep_rank`,
    given that every rank holds `local_cnt` experts.
    """
    rank_offset = ep_rank * local_cnt
    return global_idx - rank_offset
def global_idx_to_rank(
    global_idx: int,
    local_cnt: int,
) -> int:
    """
    Return the EP rank that owns the expert at `global_idx`, given that
    every rank holds `local_cnt` experts.
    """
    owner_rank = global_idx // local_cnt
    return owner_rank
def get_ep_ranks_with_expert(
    idx: int,
    num_local_experts: int,
    old_indices: Sequence[int],
    new_indices: Sequence[int],
) -> tuple[MutableSequence[int], MutableSequence[int]]:
    """
    Find which ranks can provide a given expert and which ranks must fetch it.

    Args:
        idx: The index of the expert.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two lists:
            - The ranks holding the expert in the old placement (senders).
            - The ranks holding the expert in the new placement that do not
              already hold it in the old one (receivers).
    """

    def _ranks_holding(placement: Sequence[int]) -> list[int]:
        # Positions are scanned in increasing order, so the owning rank never
        # decreases and comparing with the last entry is enough to dedupe.
        holders: list[int] = []
        for position, expert in enumerate(placement):
            if expert != idx:
                continue
            owner = position // num_local_experts
            if holders and holders[-1] == owner:
                continue
            holders.append(owner)
        return holders

    ranks_to_send = _ranks_holding(old_indices)
    ranks_to_recv = _ranks_holding(new_indices)

    # Remove those ranks that can get this expert locally.
    local_sources = set(ranks_to_send)
    ranks_to_recv_actual = [
        candidate for candidate in ranks_to_recv
        if candidate not in local_sources
    ]
    return ranks_to_send, ranks_to_recv_actual
def shuffle_layer(
    num_local_experts: int,
    ep_rank: int,
    old_indices: Sequence[int],
    new_indices: Sequence[int],
    expert_weights: Iterable[torch.Tensor],
    expert_weights_buffer: Sequence[torch.Tensor],
    ep_group: ProcessGroup,
) -> None:
    """
    Perform expert weights rearrangement of one layer.

    Weights that this rank already holds are copied locally into the buffer;
    everything else is exchanged with point-to-point sends/receives, and the
    buffer is finally copied back into `expert_weights`.

    Args:
        num_local_experts: Number of physical experts held by each rank.
        ep_rank: Rank of this process inside `ep_group`.
        old_indices: Logical expert id of every global physical slot before
            the rearrangement.
        new_indices: Logical expert id of every global physical slot after
            the rearrangement.
        expert_weights: The weight tensors of this layer; each is indexed by
            local expert along its first dimension and updated in place.
            NOTE(review): it is iterated several times, so it presumably must
            be a re-iterable container rather than a one-shot iterator.
        expert_weights_buffer: Scratch tensors paired with `expert_weights`
            via `zip`.
        ep_group: The process group used for the exchange.
    """
    local2global = partial(
        idx_local_to_global,
        local_cnt=num_local_experts,
        ep_rank=ep_rank,
    )

    # 0. Do nothing for experts that did not change.
    is_unchanged = [
        old_indices[local2global(i)] == new_indices[local2global(i)]
        for i in range(num_local_experts)
    ]

    # 1. Perform weight copy inside the local rank.
    is_received_locally = is_unchanged[:]
    for src in range(num_local_experts):
        src_global = local2global(src)
        for dst in range(num_local_experts):
            dst_global = local2global(dst)
            if is_received_locally[dst]:
                continue
            if old_indices[src_global] == new_indices[dst_global]:
                is_received_locally[dst] = True
                # Stage in the buffer rather than writing `weight[dst]`
                # directly: that slot may still be needed as a source.
                for weight, buffer in zip(expert_weights,
                                          expert_weights_buffer):
                    buffer[dst].copy_(weight[src])

    p2p_ops: list[P2POp] = []

    # 2. Initiate sending of weights.
    # Maps each logical expert held by this rank to the first local slot
    # holding it.
    experts_send_loc: dict[int, int] = {}
    for src in range(num_local_experts):
        expert = old_indices[local2global(src)]
        if expert in experts_send_loc:
            continue
        experts_send_loc[expert] = src

    # We need to sort here to match send/recv
    for expert, src in sorted(experts_send_loc.items()):
        ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
            expert,
            num_local_experts,
            old_indices,
            new_indices,
        )

        # Calculate the ranks to send by this rank
        # Receivers are split evenly over the senders, in list order.
        num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
        sender_pos = ranks_to_send.index(ep_rank)
        recv_begin = sender_pos * num_dst_per_sender
        recv_end = recv_begin + num_dst_per_sender
        recv_ranks = ranks_to_recv[recv_begin:recv_end]

        # Tackle remainders
        # The leftover receivers are handed out one each to the first senders.
        remainder_start = len(ranks_to_send) * num_dst_per_sender
        recver_pos = remainder_start + sender_pos
        if recver_pos < len(ranks_to_recv):
            recv_ranks.append(ranks_to_recv[recver_pos])

        for dst in recv_ranks:
            dst_global = get_global_rank(ep_group, dst)
            p2p_ops += [
                P2POp(
                    torch.distributed.isend,
                    weight[src],
                    dst_global,
                ) for weight in expert_weights
            ]

    # 3. Initiate receiving of weights.
    # Maps each logical expert this rank must fetch to the first local slot
    # that needs it; other slots needing it are filled from the buffer in
    # step 5.
    experts_recv_loc: dict[int, int] = {}
    for dst in range(num_local_experts):
        if is_received_locally[dst]:
            continue
        expert = new_indices[local2global(dst)]
        if expert in experts_recv_loc:
            continue
        experts_recv_loc[expert] = dst

    # We need to sort here to match send/recv
    for expert, dst in sorted(experts_recv_loc.items()):
        ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
            expert,
            num_local_experts,
            old_indices,
            new_indices,
        )

        # Calculate the rank to recv by this rank
        # Mirror of the sender-side split in step 2.
        num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
        recver_pos = ranks_to_recv.index(ep_rank)
        remainder_start = len(ranks_to_send) * num_dst_per_sender
        if recver_pos < remainder_start:
            src = ranks_to_send[recver_pos // num_dst_per_sender]
        else:
            src = ranks_to_send[recver_pos - remainder_start]

        src_global = get_global_rank(ep_group, src)
        p2p_ops += [
            P2POp(
                torch.distributed.irecv,
                weight[dst],
                src_global,
            ) for weight in expert_weights_buffer
        ]

    # 4. Execute the P2P operations. The real communication happens here.
    if p2p_ops:
        reqs = batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()

    # 5. Copy the weights from the buffer back to the original weights.
    for dst in range(num_local_experts):
        if is_unchanged[dst]:
            continue
        if is_received_locally[dst]:
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
                weight[dst].copy_(buffer[dst])
        else:
            expert = new_indices[local2global(dst)]
            # Here `src` is the buffer slot the expert was received into.
            src = experts_recv_loc[expert]
            for weight, buffer in zip(expert_weights, expert_weights_buffer):
                weight[dst].copy_(buffer[src])
def rearrange_expert_weights_inplace(
    old_global_expert_indices: torch.Tensor,
    new_global_expert_indices: torch.Tensor,
    expert_weights: Sequence[Iterable[torch.Tensor]],
    ep_group: ProcessGroup,
    is_profile: bool = False,
) -> None:
    """
    Rearranges the expert weights in place according to the new expert indices.

    In both index tensors the position along the last dimension is the
    physical expert index and the stored value is the logical expert index.

    Args:
        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
            of tensors of shape (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection,
            so weight_count = 2. Each weight's hidden size can be different.
        ep_group: The device process group for expert parallelism.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
    """
    num_moe_layers, num_physical_experts = old_global_expert_indices.shape
    assert len(expert_weights) == num_moe_layers

    # The first dimension of the first weight of the first layer gives the
    # number of experts held locally.
    num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers,
                                               num_physical_experts)

    ep_rank = ep_group.rank()
    ep_size = ep_group.size()
    assert num_physical_experts == ep_size * num_local_physical_experts

    # A buffer to hold the expert weights in one layer during the exchange.
    # NOTE: Currently we assume the same weights across different layers
    # have the same shape.
    expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]

    if is_profile:
        # Maximum send size is to send all local experts to all ranks,
        # So we use a dummy `all_gather` to reserve enough communication buffer
        for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
            # A `/dev/null`-like buffer to avoid real memory allocation
            dummy_recv_buffer = [buffer for _ in range(ep_size)]
            # NOTE(bowen): Needed this barrier to avoid OOM during actual
            # execution. I'm not very sure why this is needed
            torch.distributed.barrier()
            all_gather(
                dummy_recv_buffer,
                weight,
                group=ep_group,
            )
        return

    # Layers are shuffled one at a time, reusing the same buffer.
    for layer in range(num_moe_layers):
        # NOTE(bowen): We need this synchronize to run, but I don't know why.
        # If you figure out the reason, please let me know -- thank you!
        torch.cuda.synchronize()
        shuffle_layer(
            num_local_physical_experts,
            ep_rank,
            old_global_expert_indices[layer].tolist(),
            new_global_expert_indices[layer].tolist(),
            expert_weights[layer],
            expert_weights_buffer,
            ep_group,
        )
__all__
=
[
"rearrange_expert_weights_inplace"
]
vllm/engine/arg_utils.py
View file @
e9fd658a
...
@@ -320,6 +320,11 @@ class EngineArgs:
...
@@ -320,6 +320,11 @@ class EngineArgs:
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_rpc_port
:
Optional
[
int
]
=
None
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_eplb
:
bool
=
ParallelConfig
.
enable_eplb
num_redundant_experts
:
int
=
ParallelConfig
.
num_redundant_experts
eplb_window_size
:
int
=
ParallelConfig
.
eplb_window_size
eplb_step_interval
:
int
=
ParallelConfig
.
eplb_step_interval
eplb_log_balancedness
:
bool
=
ParallelConfig
.
eplb_log_balancedness
max_parallel_loading_workers
:
Optional
[
max_parallel_loading_workers
:
Optional
[
int
]
=
ParallelConfig
.
max_parallel_loading_workers
int
]
=
ParallelConfig
.
max_parallel_loading_workers
block_size
:
Optional
[
BlockSize
]
=
CacheConfig
.
block_size
block_size
:
Optional
[
BlockSize
]
=
CacheConfig
.
block_size
...
@@ -666,6 +671,16 @@ class EngineArgs:
...
@@ -666,6 +671,16 @@ class EngineArgs:
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--enable-expert-parallel"
,
"--enable-expert-parallel"
,
**
parallel_kwargs
[
"enable_expert_parallel"
])
**
parallel_kwargs
[
"enable_expert_parallel"
])
parallel_group
.
add_argument
(
"--enable-eplb"
,
**
parallel_kwargs
[
"enable_eplb"
])
parallel_group
.
add_argument
(
"--num-redundant-experts"
,
**
parallel_kwargs
[
"num_redundant_experts"
])
parallel_group
.
add_argument
(
"--eplb-window-size"
,
**
parallel_kwargs
[
"eplb_window_size"
])
parallel_group
.
add_argument
(
"--eplb-step-interval"
,
**
parallel_kwargs
[
"eplb_step_interval"
])
parallel_group
.
add_argument
(
"--eplb-log-balancedness"
,
**
parallel_kwargs
[
"eplb_log_balancedness"
])
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--max-parallel-loading-workers"
,
"--max-parallel-loading-workers"
,
**
parallel_kwargs
[
"max_parallel_loading_workers"
])
**
parallel_kwargs
[
"max_parallel_loading_workers"
])
...
@@ -1135,6 +1150,11 @@ class EngineArgs:
...
@@ -1135,6 +1150,11 @@ class EngineArgs:
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_rpc_port
=
data_parallel_rpc_port
,
data_parallel_backend
=
data_parallel_backend
,
data_parallel_backend
=
data_parallel_backend
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
num_redundant_experts
,
eplb_window_size
=
self
.
eplb_window_size
,
eplb_step_interval
=
self
.
eplb_step_interval
,
eplb_log_balancedness
=
self
.
eplb_log_balancedness
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
max_parallel_loading_workers
=
self
.
max_parallel_loading_workers
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
disable_custom_all_reduce
=
self
.
disable_custom_all_reduce
,
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
e9fd658a
...
@@ -3,9 +3,10 @@
...
@@ -3,9 +3,10 @@
import
importlib
import
importlib
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Literal
,
Optional
,
Union
,
overload
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -20,6 +21,7 @@ from vllm.distributed import (get_dp_group, get_ep_group,
...
@@ -20,6 +21,7 @@ from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
...
@@ -435,6 +437,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -435,6 +437,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -574,7 +580,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -574,7 +580,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet."
)
return
self
.
forward
(
return
self
.
forward
(
x
=
x
,
x
=
x
,
layer
=
layer
,
layer
=
layer
,
...
@@ -821,6 +835,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -821,6 +835,7 @@ class FusedMoE(torch.nn.Module):
reduce_results: Whether to all all_reduce on the output of the layer
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -845,6 +860,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -845,6 +860,8 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
):
):
super
().
__init__
()
super
().
__init__
()
if
params_dtype
is
None
:
if
params_dtype
is
None
:
...
@@ -860,7 +877,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -860,7 +877,7 @@ class FusedMoE(torch.nn.Module):
get_dp_group
().
world_size
),
get_dp_group
().
world_size
),
vllm_parallel_config
=
vllm_config
.
parallel_config
))
vllm_parallel_config
=
vllm_config
.
parallel_config
))
self
.
global_num_experts
=
num_experts
self
.
global_num_experts
=
num_experts
+
num_redundant_experts
# For smuggling this layer into the fused moe custom op
# For smuggling this layer into the fused moe custom op
compilation_config
=
vllm_config
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
...
@@ -869,8 +886,20 @@ class FusedMoE(torch.nn.Module):
...
@@ -869,8 +886,20 @@ class FusedMoE(torch.nn.Module):
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
enable_eplb
=
enable_eplb
self
.
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
self
.
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
self
.
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
# Determine expert maps
# Determine expert maps
if
self
.
use_ep
:
if
self
.
use_ep
:
if
self
.
enable_eplb
:
assert
self
.
global_num_experts
%
self
.
ep_size
==
0
,
\
"EPLB currently only supports even distribution of "
\
"experts across ranks."
else
:
assert
num_redundant_experts
==
0
,
\
"Redundant experts are only supported with EPLB."
self
.
local_num_experts
,
self
.
expert_map
=
determine_expert_map
(
self
.
local_num_experts
,
self
.
expert_map
=
determine_expert_map
(
ep_size
=
self
.
ep_size
,
ep_size
=
self
.
ep_size
,
ep_rank
=
self
.
ep_rank
,
ep_rank
=
self
.
ep_rank
,
...
@@ -937,6 +966,20 @@ class FusedMoE(torch.nn.Module):
...
@@ -937,6 +966,20 @@ class FusedMoE(torch.nn.Module):
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
self
.
quant_method
=
quant_method
self
.
quant_method
=
quant_method
if
self
.
enable_eplb
:
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8MoEMethod
)
if
not
isinstance
(
quant_method
,
Fp8MoEMethod
):
# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
# design causes duplicated work when extending to new
# quantization methods, so I'm leaving it for now.
# If you plan to add support for more quantization methods,
# please refer to the implementation in `Fp8MoEMethod`.
raise
NotImplementedError
(
"EPLB is only supported for FP8 "
"quantization for now."
)
moe_quant_params
=
{
moe_quant_params
=
{
"num_experts"
:
self
.
local_num_experts
,
"num_experts"
:
self
.
local_num_experts
,
"hidden_size"
:
hidden_size
,
"hidden_size"
:
hidden_size
,
...
@@ -965,8 +1008,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -965,8 +1008,9 @@ class FusedMoE(torch.nn.Module):
dtype
=
act_dtype
,
dtype
=
act_dtype
,
device
=
torch
.
cuda
.
current_device
())
device
=
torch
.
cuda
.
current_device
())
# Note here we use `num_experts` which is logical expert count
self
.
batched_router_logits
=
torch
.
zeros
(
self
.
batched_router_logits
=
torch
.
zeros
(
(
envs
.
VLLM_MOE_DP_CHUNK_SIZE
,
self
.
global_
num_experts
),
(
envs
.
VLLM_MOE_DP_CHUNK_SIZE
,
num_experts
),
dtype
=
act_dtype
,
dtype
=
act_dtype
,
device
=
torch
.
cuda
.
current_device
())
device
=
torch
.
cuda
.
current_device
())
...
@@ -1130,13 +1174,33 @@ class FusedMoE(torch.nn.Module):
...
@@ -1130,13 +1174,33 @@ class FusedMoE(torch.nn.Module):
return
expert_id
return
expert_id
return
self
.
expert_map
[
expert_id
].
item
()
return
self
.
expert_map
[
expert_id
].
item
()
@
overload
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
,
return_success
:
Literal
[
False
])
->
None
:
...
@
overload
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
,
return_success
:
Literal
[
True
])
->
bool
:
...
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
,
return_success
:
bool
=
False
)
->
Optional
[
bool
]:
expert_id
=
self
.
_map_global_expert_id_to_local_expert_id
(
expert_id
)
expert_id
=
self
.
_map_global_expert_id_to_local_expert_id
(
expert_id
)
if
expert_id
==
-
1
:
if
expert_id
==
-
1
:
return
# Failed to load this param since it's not local to this rank
return
False
if
return_success
else
None
# Hereafter, `expert_id` is local physical id
quant_method_name
=
self
.
quant_method
.
__class__
.
__name__
quant_method_name
=
self
.
quant_method
.
__class__
.
__name__
# compressed-tensors checkpoints with packed weights are stored flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# TODO (mgoin): check self.quant_method.quant_config.quant_format
...
@@ -1163,7 +1227,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1163,7 +1227,7 @@ class FusedMoE(torch.nn.Module):
if
is_gguf_weight_type
:
if
is_gguf_weight_type
:
param
.
weight_type
=
loaded_weight
.
item
()
param
.
weight_type
=
loaded_weight
.
item
()
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
return
return
True
if
return_success
else
None
# is_transposed: if the dim to shard the weight
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be flipped. Required by GPTQ, compressed-tensors
...
@@ -1202,7 +1266,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1202,7 +1266,7 @@ class FusedMoE(torch.nn.Module):
self
.
_load_single_value
(
param
=
param
,
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
return
return
True
if
return_success
else
None
# Case g_idx
# Case g_idx
if
"g_idx"
in
weight_name
:
if
"g_idx"
in
weight_name
:
...
@@ -1211,7 +1275,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1211,7 +1275,7 @@ class FusedMoE(torch.nn.Module):
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
self
.
tp_rank
)
tp_rank
=
self
.
tp_rank
)
return
return
True
if
return_success
else
None
if
"ModelOpt"
in
quant_method_name
:
if
"ModelOpt"
in
quant_method_name
:
if
(
'weight_scale_2'
in
weight_name
if
(
'weight_scale_2'
in
weight_name
...
@@ -1227,7 +1291,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1227,7 +1291,7 @@ class FusedMoE(torch.nn.Module):
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
self
.
tp_rank
)
tp_rank
=
self
.
tp_rank
)
return
return
True
if
return_success
else
None
# Case weight scales, zero_points and offset
# Case weight scales, zero_points and offset
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
...
@@ -1264,7 +1328,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1264,7 +1328,7 @@ class FusedMoE(torch.nn.Module):
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
return
True
if
return_success
else
None
# Case weight_shape
# Case weight_shape
if
"weight_shape"
in
weight_name
:
if
"weight_shape"
in
weight_name
:
...
@@ -1272,7 +1336,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -1272,7 +1336,7 @@ class FusedMoE(torch.nn.Module):
self
.
_load_single_value
(
param
=
param
,
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
return
return
True
if
return_success
else
None
# Case model weights
# Case model weights
if
"weight"
in
weight_name
:
if
"weight"
in
weight_name
:
...
@@ -1282,10 +1346,46 @@ class FusedMoE(torch.nn.Module):
...
@@ -1282,10 +1346,46 @@ class FusedMoE(torch.nn.Module):
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
expert_data
=
expert_data
,
tp_rank
=
self
.
tp_rank
)
tp_rank
=
self
.
tp_rank
)
return
return
True
if
return_success
else
None
return
False
if
return_success
else
None
def get_expert_weights(self) -> Iterable[torch.Tensor]:
    """
    Collect this layer's per-expert parameters as 2-D views.

    Every parameter reported by `named_parameters()` must be contiguous.
    Each retained parameter is viewed as `(local_num_experts, -1)`, so
    one row holds the flattened data of one local expert.

    Returns:
        A list with one reshaped view per expert parameter, in the order
        given by `named_parameters()`.
    """
    # `e_score_correction_bias` is a bias for each logical expert,
    # with shape (num_logical_experts,), not an expert weight,
    # so it is excluded from the result.
    skipped_names = {"e_score_correction_bias"}

    named = list(self.named_parameters())
    # The contiguity requirement covers all parameters, including the
    # ones that are skipped below.
    assert all(tensor.is_contiguous() for _, tensor in named)

    per_expert_views = []
    for param_name, tensor in named:
        if param_name in skipped_names:
            continue
        per_expert_views.append(tensor.view(self.local_num_experts, -1))
    return per_expert_views
def set_eplb_state(
    self,
    moe_layer_idx: int,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """
    Attach the EPLB state for this layer.

    Each of the three tensors is indexed by `moe_layer_idx` along its
    first dimension, and the resulting slice is stored on the layer
    under the attribute of the same name. The forward pass uses these
    later to look up the expert mapping and to record load metrics in
    `expert_load_view`.

    Args:
        moe_layer_idx: Index selecting this layer's slice of each tensor.
        expert_load_view: Tensor whose `moe_layer_idx` slice receives
            the load metrics.
        logical_to_physical_map: Tensor whose `moe_layer_idx` slice is
            this layer's logical-to-physical expert mapping.
        logical_replica_count: Tensor whose `moe_layer_idx` slice is
            this layer's replica counts.
    """
    # Keep the original order: index, then assign, one tensor at a time.
    per_layer_sources = (
        ("expert_load_view", expert_load_view),
        ("logical_to_physical_map", logical_to_physical_map),
        ("logical_replica_count", logical_replica_count),
    )
    for attr_name, stacked in per_layer_sources:
        setattr(self, attr_name, stacked[moe_layer_idx])
@
staticmethod
@
staticmethod
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
top_k
:
int
,
use_grouped_topk
:
bool
,
use_grouped_topk
:
bool
,
...
@@ -1295,10 +1395,28 @@ class FusedMoE(torch.nn.Module):
...
@@ -1295,10 +1395,28 @@ class FusedMoE(torch.nn.Module):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
indices_type
:
Optional
[
torch
.
dtype
]
=
None
):
indices_type
:
Optional
[
torch
.
dtype
]
=
None
,
enable_eplb
:
bool
=
False
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Route the input hidden states to the top-k experts based on the
router logits.
Returns:
(topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
The weights and *global physical* expert ids of the top-k experts.
**Compatibility**: When EPLB is not enabled, the returned ids are
equivalent to global logical ids, so should be compatible with
plain MoE implementations without redundant experts.
"""
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
# Dee
k
Seekv2 uses grouped_top_k
# Dee
p
Seekv2 uses grouped_top_k
if
use_grouped_topk
:
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
assert
num_expert_group
is
not
None
...
@@ -1330,6 +1448,74 @@ class FusedMoE(torch.nn.Module):
...
@@ -1330,6 +1448,74 @@ class FusedMoE(torch.nn.Module):
if
indices_type
is
not
None
:
if
indices_type
is
not
None
:
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
if
enable_eplb
:
assert
expert_load_view
is
not
None
assert
logical_to_physical_map
is
not
None
assert
logical_replica_count
is
not
None
# 1. Convert the logical expert ids to physical expert ids
# Directly select a random replica for each logical expert
# TODO: maybe optimize this by using specified kernels,
# or compute pseudo-random indices by modulo
# In case `indices_type` is not `torch.long` or `torch.int`,
# e.g. `torch.uint32` as required by dispatch/combine kernels
topk_ids_long
=
topk_ids
.
long
()
replica_indices
=
(
torch
.
rand_like
(
topk_ids
,
dtype
=
torch
.
float
)
*
logical_replica_count
[
topk_ids_long
]).
long
().
unsqueeze
(
-
1
)
physical_ids
=
logical_to_physical_map
[
topk_ids_long
].
gather
(
-
1
,
replica_indices
).
squeeze
(
-
1
)
topk_ids
=
physical_ids
# 2. Record expert load metrics.
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.
# `expert_load_view`: (num_logical_experts,)
# Mask out non-local experts
if
expert_map
is
not
None
:
topk_ids_local
=
expert_map
[
topk_ids
]
topk_ids_flatten
=
topk_ids_local
.
flatten
()
else
:
topk_ids_flatten
=
topk_ids
.
flatten
()
# Should be equivalent to:
# ```
# topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
# expert_load_view += topk_ids_masked.bincount(
# minlength=expert_load_view.shape[0])
# ```
# We use `scatter_add_` since `bincount` cannot be compiled
# Performance optimization:
# `masked_fill` is significantly faster than `masked_select`
invalid_mask
=
topk_ids_flatten
<
0
# Replace invalid expert ids with 0 (just a dummy position)
# to avoid out-of-bounds errors in scatter_add_
index
=
topk_ids_flatten
.
masked_fill_
(
invalid_mask
,
0
)
# `src` is the valid mask, which is 1 for valid and 0 for invalid
src
=
~
invalid_mask
expert_load_view
.
scatter_add_
(
dim
=
0
,
index
=
index
.
long
(),
src
=
src
.
to
(
expert_load_view
))
topk_ids
=
topk_ids
.
to
(
dtype
=
indices_type
)
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
...
@@ -1410,6 +1596,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -1410,6 +1596,10 @@ class FusedMoE(torch.nn.Module):
scoring_func
=
self
.
scoring_func
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
enable_eplb
=
self
.
enable_eplb
,
expert_load_view
=
self
.
expert_load_view
,
logical_to_physical_map
=
self
.
logical_to_physical_map
,
logical_replica_count
=
self
.
logical_replica_count
,
)
)
if
not
skip_result_store
:
if
not
skip_result_store
:
...
@@ -1467,6 +1657,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -1467,6 +1657,10 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
=
self
.
e_score_correction_bias
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
enable_eplb
=
self
.
enable_eplb
,
expert_load_view
=
self
.
expert_load_view
,
logical_to_physical_map
=
self
.
logical_to_physical_map
,
logical_replica_count
=
self
.
logical_replica_count
,
)
)
if
do_naive_dispatch_combine
:
if
do_naive_dispatch_combine
:
...
@@ -1481,16 +1675,30 @@ class FusedMoE(torch.nn.Module):
...
@@ -1481,16 +1675,30 @@ class FusedMoE(torch.nn.Module):
@
classmethod
@
classmethod
def
make_expert_params_mapping
(
def
make_expert_params_mapping
(
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
num_experts
:
int
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
num_experts
:
int
,
num_redundant_experts
:
int
=
0
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
num_physical_experts
=
num_experts
+
num_redundant_experts
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map
=
\
EplbState
.
build_initial_global_physical_to_logical_map
(
num_experts
,
num_redundant_experts
)
return
[
return
[
# (param_name, weight_name, expert_id, shard_id)
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_"
if
weight_name
(
"experts.w13_"
if
weight_name
in
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
else
"experts.w2_"
,
in
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
else
"experts.w2_"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
."
,
expert_id
,
shard_id
)
f
"experts.
{
physical_to_logical_map
[
expert_id
]
}
.
{
weight_name
}
."
,
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
[
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_physical_experts
)
for
shard_id
,
weight_name
in
[
(
"w1"
,
ckpt_gate_proj_name
),
(
"w1"
,
ckpt_gate_proj_name
),
(
"w2"
,
ckpt_down_proj_name
),
(
"w2"
,
ckpt_down_proj_name
),
(
"w3"
,
ckpt_up_proj_name
),
(
"w3"
,
ckpt_up_proj_name
),
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
e9fd658a
...
@@ -482,7 +482,15 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -482,7 +482,15 @@ class AWQMoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `AWQMoEMethod` yet."
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
e9fd658a
...
@@ -331,7 +331,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -331,7 +331,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Fp8MoEMethod` yet."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -593,7 +601,15 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -593,7 +601,15 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Fp8MoECutlassMethod` yet."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -722,7 +738,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -722,7 +738,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
@@ -1012,7 +1037,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -1012,7 +1037,16 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsWNA16MarlinMoEMethod` yet."
)
assert
activation
==
"silu"
,
(
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Marlin MoE."
)
f
"
{
activation
}
not supported for Marlin MoE."
)
assert
not
apply_router_weight_on_input
,
(
assert
not
apply_router_weight_on_input
,
(
...
@@ -1228,7 +1262,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1228,7 +1262,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsWNA16MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
e9fd658a
...
@@ -117,7 +117,15 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -117,7 +117,15 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
e9fd658a
...
@@ -825,7 +825,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -825,7 +825,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
assert
expert_load_view
is
not
None
assert
logical_to_physical_map
is
not
None
assert
logical_replica_count
is
not
None
assert
isinstance
(
layer
,
FusedMoE
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -839,6 +848,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -839,6 +848,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
indices_type
=
self
.
topk_indices_dtype
,
enable_eplb
=
enable_eplb
,
expert_map
=
expert_map
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
,
)
)
if
self
.
rocm_aiter_moe_enabled
:
if
self
.
rocm_aiter_moe_enabled
:
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
e9fd658a
...
@@ -520,7 +520,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
...
@@ -520,7 +520,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `GGUFMoEMethod` yet."
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
e9fd658a
...
@@ -635,7 +635,15 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -635,7 +635,15 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `GPTQMarlinMoEMethod` yet."
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
e9fd658a
...
@@ -664,7 +664,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -664,7 +664,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
e9fd658a
...
@@ -297,7 +297,15 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -297,7 +297,15 @@ class MoeWNA16Method(FusedMoEMethodBase):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `MoeWNA16Method` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
e9fd658a
...
@@ -205,7 +205,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -205,7 +205,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment