Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
781096e3
Unverified
Commit
781096e3
authored
Feb 24, 2025
by
Jongseok Park
Committed by
GitHub
Feb 24, 2025
Browse files
Expert Parallelism (EP) Support for DeepSeek V2 (#12583)
parent
7940d8a6
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
527 additions
and
59 deletions
+527
-59
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+2
-1
tests/distributed/test_expert_parallel.py
tests/distributed/test_expert_parallel.py
+227
-0
tests/kernels/test_awq_marlin.py
tests/kernels/test_awq_marlin.py
+2
-7
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+59
-6
tests/kernels/utils.py
tests/kernels/utils.py
+3
-1
tests/utils.py
tests/utils.py
+3
-3
vllm/config.py
vllm/config.py
+20
-0
vllm/envs.py
vllm/envs.py
+7
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+96
-30
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+65
-5
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
+8
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+7
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+10
-0
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+4
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+4
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+4
-0
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+4
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+0
-4
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
781096e3
...
@@ -468,7 +468,8 @@ def main(args: argparse.Namespace):
...
@@ -468,7 +468,8 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
elif
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
:
elif
(
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
...
...
tests/distributed/test_expert_parallel.py
0 → 100644
View file @
781096e3
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
List
,
Literal
,
NamedTuple
,
Optional
import
pytest
from
vllm.config
import
TaskOption
from
vllm.logger
import
init_logger
from
..utils
import
compare_two_settings
,
fork_new_process_for_each_test
logger
=
init_logger
(
"test_expert_parallel"
)
class
ParallelSetup
(
NamedTuple
):
tp_size
:
int
eager_mode
:
bool
chunked_prefill
:
bool
class
EPTestOptions
(
NamedTuple
):
trust_remote_code
:
bool
tokenizer_mode
:
Optional
[
str
]
load_format
:
Optional
[
str
]
=
None
hf_overrides
:
Optional
[
str
]
=
None
@
dataclass
class
EPTestSettings
:
parallel_setups
:
List
[
ParallelSetup
]
distributed_backends
:
List
[
str
]
task
:
TaskOption
test_options
:
EPTestOptions
@
staticmethod
def
detailed
(
*
,
tp_base
:
int
=
2
,
task
:
TaskOption
=
"auto"
,
trust_remote_code
:
bool
=
False
,
tokenizer_mode
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
hf_overrides
:
Optional
[
str
]
=
None
,
):
return
EPTestSettings
(
parallel_setups
=
[
ParallelSetup
(
tp_size
=
tp_base
,
eager_mode
=
False
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
tp_base
,
eager_mode
=
False
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
tp_base
,
eager_mode
=
True
,
chunked_prefill
=
False
),
ParallelSetup
(
tp_size
=
2
*
tp_base
,
eager_mode
=
False
,
chunked_prefill
=
True
),
ParallelSetup
(
tp_size
=
2
*
tp_base
,
eager_mode
=
True
,
chunked_prefill
=
False
),
],
distributed_backends
=
[
"mp"
,
"ray"
],
task
=
task
,
test_options
=
EPTestOptions
(
trust_remote_code
=
trust_remote_code
,
tokenizer_mode
=
tokenizer_mode
,
load_format
=
load_format
,
hf_overrides
=
hf_overrides
),
)
@
staticmethod
def
fast
(
*
,
tp_base
:
int
=
2
,
task
:
TaskOption
=
"auto"
,
trust_remote_code
:
bool
=
False
,
tokenizer_mode
:
Optional
[
str
]
=
None
,
load_format
:
Optional
[
str
]
=
None
,
hf_overrides
:
Optional
[
str
]
=
None
,
):
return
EPTestSettings
(
parallel_setups
=
[
ParallelSetup
(
tp_size
=
tp_base
,
eager_mode
=
True
,
chunked_prefill
=
False
),
],
distributed_backends
=
[
"mp"
],
task
=
task
,
test_options
=
EPTestOptions
(
trust_remote_code
=
trust_remote_code
,
tokenizer_mode
=
tokenizer_mode
,
load_format
=
load_format
,
hf_overrides
=
hf_overrides
),
)
def
iter_params
(
self
,
model_name
:
str
):
opts
=
self
.
test_options
for
parallel_setup
in
self
.
parallel_setups
:
for
distributed_backend
in
self
.
distributed_backends
:
yield
(
model_name
,
parallel_setup
,
distributed_backend
,
self
.
task
,
opts
)
# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model
# yapf: disable
TEST_MODELS
=
{
"deepseek-ai/DeepSeek-V2-Lite-Chat"
:
EPTestSettings
.
fast
(
trust_remote_code
=
True
),
"mistralai/Mixtral-8x7B-Instruct-v0.1"
:
EPTestSettings
.
fast
(
tp_base
=
4
),
}
def
_compare_tp
(
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
task
:
TaskOption
,
test_options
:
EPTestOptions
,
num_gpus_available
:
int
,
*
,
method
:
Literal
[
"generate"
],
):
(
tp_size
,
eager_mode
,
chunked_prefill
,
)
=
parallel_setup
(
trust_remote_code
,
tokenizer_mode
,
load_format
,
hf_overrides
,
)
=
test_options
if
num_gpus_available
<
tp_size
:
pytest
.
skip
(
f
"Need at least
{
tp_size
}
GPUs"
)
common_args
=
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"float16"
,
"--max-model-len"
,
"2048"
,
"--max-num-seqs"
,
"8"
,
"--load-format"
,
"auto"
,
]
if
chunked_prefill
:
common_args
.
append
(
"--enable-chunked-prefill"
)
if
eager_mode
:
common_args
.
append
(
"--enforce-eager"
)
if
task
!=
"auto"
:
common_args
.
extend
([
"--task"
,
task
])
if
trust_remote_code
:
common_args
.
append
(
"--trust-remote-code"
)
if
tokenizer_mode
:
common_args
.
extend
([
"--tokenizer-mode"
,
tokenizer_mode
])
if
load_format
:
common_args
.
extend
([
"--load-format"
,
load_format
])
if
hf_overrides
:
common_args
.
extend
([
"--hf-overrides"
,
hf_overrides
])
ep_env
=
{
"VLLM_TEST_ENABLE_EP"
:
"1"
,
}
ep_args
=
[
*
common_args
,
"--tensor-parallel-size"
,
str
(
tp_size
),
"--distributed-executor-backend"
,
distributed_backend
,
]
# compare without expert parallelism
tp_env
=
{
"VLLM_TEST_ENABLE_EP"
:
"0"
,
}
tp_args
=
[
*
common_args
,
"--tensor-parallel-size"
,
str
(
tp_size
),
"--distributed-executor-backend"
,
"mp"
,
]
try
:
compare_two_settings
(
model_name
,
ep_args
,
tp_args
,
ep_env
,
tp_env
,
method
=
method
,
max_wait_seconds
=
360
)
except
Exception
:
raise
@
pytest
.
mark
.
parametrize
(
(
"model_name"
,
"parallel_setup"
,
"distributed_backend"
,
"task"
,
"test_options"
),
[
params
for
model_name
,
settings
in
TEST_MODELS
.
items
()
for
params
in
settings
.
iter_params
(
model_name
)
],
)
@
fork_new_process_for_each_test
def
test_ep
(
model_name
:
str
,
parallel_setup
:
ParallelSetup
,
distributed_backend
:
str
,
task
:
TaskOption
,
test_options
:
EPTestOptions
,
num_gpus_available
,
):
_compare_tp
(
model_name
,
parallel_setup
,
distributed_backend
,
task
,
test_options
,
num_gpus_available
,
method
=
"generate"
)
tests/kernels/test_awq_marlin.py
View file @
781096e3
...
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
...
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
num_bits
=
num_bits
,
num_bits
=
num_bits
,
)
)
torch_output
=
torch_moe
(
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
a
,
score
,
topk
,
None
)
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
...
...
tests/kernels/test_moe.py
View file @
781096e3
...
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
...
@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
TOP_KS
=
[
2
,
6
]
...
@@ -34,6 +35,7 @@ TOP_KS = [2, 6]
...
@@ -34,6 +35,7 @@ TOP_KS = [2, 6]
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
def
test_fused_moe
(
def
test_fused_moe
(
m
:
int
,
m
:
int
,
...
@@ -41,6 +43,7 @@ def test_fused_moe(
...
@@ -41,6 +43,7 @@ def test_fused_moe(
k
:
int
,
k
:
int
,
e
:
int
,
e
:
int
,
topk
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
):
):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -48,10 +51,38 @@ def test_fused_moe(
...
@@ -48,10 +51,38 @@ def test_fused_moe(
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1
=
w1
[
e_ids
]
w2
=
w2
[
e_ids
]
else
:
e_map
=
None
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
atol
=
2e-2
,
atol
=
2e-2
,
...
@@ -63,13 +94,14 @@ def test_fused_moe(
...
@@ -63,13 +94,14 @@ def test_fused_moe(
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"weight_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"weight_bits"
,
[
4
,
8
])
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
weight_bits
:
int
):
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if
has_zp
:
if
has_zp
:
w_qzeros
[
expert_id
]
=
qzeros
w_qzeros
[
expert_id
]
=
qzeros
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1_ref
=
w1_ref
[
e_ids
]
w2_ref
=
w2_ref
[
e_ids
]
w1_qweight
=
w1_qweight
[
e_ids
]
w2_qweight
=
w2_qweight
[
e_ids
]
w1_scales
=
w1_scales
[
e_ids
]
w2_scales
=
w2_scales
[
e_ids
]
w1_qzeros
=
w1_qzeros
[
e_ids
]
w2_qzeros
=
w2_qzeros
[
e_ids
]
else
:
e_map
=
None
triton_output
=
fused_moe
(
a
,
triton_output
=
fused_moe
(
a
,
w1_qweight
,
w1_qweight
,
w2_qweight
,
w2_qweight
,
...
@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize
=
False
,
renormalize
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int8_w8a16
=
weight_bits
==
8
,
global_num_experts
=
e
,
expert_map
=
e_map
,
w1_scale
=
w1_scales
,
w1_scale
=
w1_scales
,
w2_scale
=
w2_scales
,
w2_scale
=
w2_scales
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
...
tests/kernels/utils.py
View file @
781096e3
...
@@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref):
...
@@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref):
torch
.
abs
(
output_ref
))
torch
.
abs
(
output_ref
))
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
expert_map
):
B
,
D
=
a
.
shape
B
,
D
=
a
.
shape
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
D
)
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
D
)
out
=
torch
.
zeros
(
B
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
out
=
torch
.
zeros
(
B
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
...
@@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk):
...
@@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
topk_weight
=
topk_weight
.
view
(
-
1
)
topk_weight
=
topk_weight
.
view
(
-
1
)
topk_ids
=
topk_ids
.
view
(
-
1
)
topk_ids
=
topk_ids
.
view
(
-
1
)
if
expert_map
is
not
None
:
topk_ids
=
expert_map
[
topk_ids
]
for
i
in
range
(
w1
.
shape
[
0
]):
for
i
in
range
(
w1
.
shape
[
0
]):
mask
=
topk_ids
==
i
mask
=
topk_ids
==
i
if
mask
.
sum
():
if
mask
.
sum
():
...
...
tests/utils.py
View file @
781096e3
...
@@ -297,12 +297,12 @@ def _test_completion_close(
...
@@ -297,12 +297,12 @@ def _test_completion_close(
logprobs
=
5
,
logprobs
=
5
,
temperature
=
0.0
)
temperature
=
0.0
)
logp
o
rbs
=
completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
]
logpr
o
bs
=
completion
.
choices
[
0
].
logprobs
.
top_logprobs
[
0
]
logp
o
rbs
=
{
k
:
round
(
v
,
2
)
for
k
,
v
in
logp
o
rbs
.
items
()}
logpr
o
bs
=
{
k
:
round
(
v
,
2
)
for
k
,
v
in
logpr
o
bs
.
items
()}
results
.
append
({
results
.
append
({
"test"
:
"completion_close"
,
"test"
:
"completion_close"
,
"logprobs"
:
logp
o
rbs
,
"logprobs"
:
logpr
o
bs
,
})
})
return
results
return
results
...
...
vllm/config.py
View file @
781096e3
...
@@ -677,6 +677,23 @@ class ModelConfig:
...
@@ -677,6 +677,23 @@ class ModelConfig:
"fallback to the eager mode."
)
"fallback to the eager mode."
)
self
.
enforce_eager
=
True
self
.
enforce_eager
=
True
def
_verify_with_expert_parallelism
(
self
)
->
None
:
num_expert_names
=
[
"moe_num_experts"
,
# Dbrx
"num_experts"
,
# Jamba
"n_routed_experts"
,
# DeepSeek
"num_local_experts"
,
# Mixtral
]
num_experts
=
0
for
name
in
num_expert_names
:
num_experts
=
getattr
(
self
.
hf_text_config
,
name
,
0
)
if
num_experts
>
0
:
break
if
num_experts
<
1
:
raise
ValueError
(
"Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled."
)
def
verify_async_output_proc
(
self
,
parallel_config
,
speculative_config
,
def
verify_async_output_proc
(
self
,
parallel_config
,
speculative_config
,
device_config
)
->
None
:
device_config
)
->
None
:
if
not
self
.
use_async_output_proc
:
if
not
self
.
use_async_output_proc
:
...
@@ -730,6 +747,9 @@ class ModelConfig:
...
@@ -730,6 +747,9 @@ class ModelConfig:
" must be divisible by tensor parallel size "
" must be divisible by tensor parallel size "
f
"(
{
tensor_parallel_size
}
)."
)
f
"(
{
tensor_parallel_size
}
)."
)
if
envs
.
VLLM_TEST_ENABLE_EP
:
self
.
_verify_with_expert_parallelism
()
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
pipeline_parallel_size
=
parallel_config
.
pipeline_parallel_size
if
pipeline_parallel_size
>
1
:
if
pipeline_parallel_size
>
1
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
...
...
vllm/envs.py
View file @
781096e3
...
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
...
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
bool
=
True
VLLM_MLA_PERFORM_MATRIX_ABSORPTION
:
bool
=
True
VLLM_MLA_DISABLE_REQUANTIZATION
:
bool
=
False
VLLM_MLA_DISABLE_REQUANTIZATION
:
bool
=
False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE
:
bool
=
True
VLLM_TEST_ENABLE_EP
:
bool
=
False
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
...
@@ -570,6 +571,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -570,6 +571,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
,
"0"
))
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
,
"0"
))
),
),
# If set, vLLM will use the experimental expert parallel implementation on
# the FusedMoE layer, using tensor parallelism size as expert parallelism
# size.
"VLLM_TEST_ENABLE_EP"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_TEST_ENABLE_EP"
,
"0"
))),
# Number of GPUs per worker in Ray, if it is set to be a fraction,
# Number of GPUs per worker in Ray, if it is set to be a fraction,
# it allows ray to schedule multiple actors on a single GPU,
# it allows ray to schedule multiple actors on a single GPU,
# so that users can colocate other actors on the same GPUs as vLLM.
# so that users can colocate other actors on the same GPUs as vLLM.
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
781096e3
...
@@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op
...
@@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
):
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
compute_type
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel_gptq_awq
(
def
fused_moe_kernel_gptq_awq
(
# Pointers to matrices
# Pointers to matrices
...
@@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq(
...
@@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
use_int4_w4a16
:
if
use_int4_w4a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
\
stride_bn
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
elif
use_int8_w8a16
:
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
...
@@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq(
...
@@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq(
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsn
+
\
offs_bn
[
None
,
:]
*
stride_bsn
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsk
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
\
stride_bsk
b_scale
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_scale
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_scale
=
b_scale
.
to
(
tl
.
float32
)
b_scale
=
b_scale
.
to
(
tl
.
float32
)
...
@@ -319,13 +341,22 @@ def fused_moe_kernel(
...
@@ -319,13 +341,22 @@ def fused_moe_kernel(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_int8_w8a16
:
if
use_int8_w8a16
:
...
@@ -349,7 +380,6 @@ def fused_moe_kernel(
...
@@ -349,7 +380,6 @@ def fused_moe_kernel(
# of fp32 values for higher accuracy.
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
# `accumulator` will be converted back to fp16 after the loop.
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
# Load the next block of A and B, generate a mask by checking the
# Load the next block of A and B, generate a mask by checking the
# K dimension.
# K dimension.
...
@@ -544,8 +574,11 @@ def moe_align_block_size_triton(
...
@@ -544,8 +574,11 @@ def moe_align_block_size_triton(
def
moe_align_block_size
(
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
block_size
:
int
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns the token distribution across experts to be compatible with block
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
size for matrix multiplication.
...
@@ -555,6 +588,10 @@ def moe_align_block_size(
...
@@ -555,6 +588,10 @@ def moe_align_block_size(
top-k expert indices for each token.
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
Returns:
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
- sorted_token_ids: A tensor containing the sorted token indices according
...
@@ -589,7 +626,9 @@ def moe_align_block_size(
...
@@ -589,7 +626,9 @@ def moe_align_block_size(
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids
=
torch
.
zeros
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
num_tokens_post_pad
=
torch
.
empty
((
1
),
...
@@ -618,6 +657,9 @@ def moe_align_block_size(
...
@@ -618,6 +657,9 @@ def moe_align_block_size(
else
:
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
expert_ids
,
num_tokens_post_pad
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
...
@@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake(
...
@@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake(
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1049,6 +1096,8 @@ def outplace_fused_experts(
...
@@ -1049,6 +1096,8 @@ def outplace_fused_experts(
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1058,8 +1107,9 @@ def outplace_fused_experts(
...
@@ -1058,8 +1107,9 @@ def outplace_fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
a1_scale
,
a2_scale
,
block_shape
)
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
outplace_fused_experts_fake
(
def
outplace_fused_experts_fake
(
...
@@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake(
...
@@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake(
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
if
inplace
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
torch
.
ops
.
vllm
.
inplace_fused_experts
(
topk_weights
,
topk_ids
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_int4_w4a16
,
w1_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
return
hidden_states
return
hidden_states
else
:
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_z
p
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_ma
p
,
a1_scale
,
a2_scale
,
block_shape
)
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
@@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
num_tokens
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
shape
[
1
]
# We execute the fused_moe kernel in chunks to circumvent this issue:
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
...
@@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
try_get_optimal_moe_config
,
try_get_optimal_moe_config
,
w1
.
shape
,
w1
.
shape
,
w2
.
shape
,
w2
.
shape
,
topk_
ids
.
shape
[
1
]
,
top
_
k_
num
,
config_dtype
,
config_dtype
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
)
)
config
=
get_config_func
(
M
)
config
=
get_config_func
(
M
)
intermediate_cache1
=
torch
.
empty
((
M
,
topk_
ids
.
shape
[
1
]
,
N
),
intermediate_cache1
=
torch
.
empty
((
M
,
top
_
k_
num
,
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
intermediate_cache2
=
torch
.
empty
((
M
*
topk_
ids
.
shape
[
1
]
,
N
//
2
),
intermediate_cache2
=
torch
.
empty
((
M
*
top
_
k_
num
,
N
//
2
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
intermediate_cache3
=
torch
.
empty
((
M
,
topk_
ids
.
shape
[
1
]
,
w2
.
shape
[
1
]),
intermediate_cache3
=
torch
.
empty
((
M
,
top
_
k_
num
,
w2
.
shape
[
1
]),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
...
@@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
w1
,
...
@@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
expert_ids
,
expert_ids
,
num_tokens_post_padded
,
num_tokens_post_padded
,
False
,
False
,
topk_
ids
.
shape
[
1
]
,
top
_
k_
num
,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
...
@@ -1286,6 +1345,8 @@ def fused_moe(
...
@@ -1286,6 +1345,8 @@ def fused_moe(
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1320,6 +1381,11 @@ def fused_moe(
...
@@ -1320,6 +1381,11 @@ def fused_moe(
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
activation to compute the inner products for w1 and w2.
Defaults to False.
Defaults to False.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
...
@@ -1334,8 +1400,6 @@ def fused_moe(
...
@@ -1334,8 +1400,6 @@ def fused_moe(
Returns:
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
- torch.Tensor: The output tensor after applying the MoE layer.
"""
"""
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
if
use_grouped_topk
:
if
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
...
@@ -1358,6 +1422,8 @@ def fused_moe(
...
@@ -1358,6 +1422,8 @@ def fused_moe(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w1_zp
=
w1_zp
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
781096e3
...
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Tuple
...
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Tuple
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
...
@@ -55,6 +56,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -55,6 +56,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -113,6 +116,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -113,6 +116,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -125,6 +130,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -125,6 +130,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk
=
use_grouped_topk
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
...
@@ -139,6 +146,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -139,6 +146,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -160,7 +169,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -160,7 +169,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
)
inplace
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
def
forward_cpu
(
def
forward_cpu
(
self
,
self
,
...
@@ -172,6 +183,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -172,6 +183,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -196,6 +209,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -196,6 +209,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -215,6 +230,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -215,6 +230,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
topk
=
top_k
,
gating_output
=
router_logits
,
gating_output
=
router_logits
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
renormalize
=
renormalize
)
renormalize
=
renormalize
)
forward_native
=
forward_cuda
forward_native
=
forward_cuda
...
@@ -255,6 +272,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -255,6 +272,7 @@ class FusedMoE(torch.nn.Module):
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
ep_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
...
@@ -267,8 +285,13 @@ class FusedMoE(torch.nn.Module):
...
@@ -267,8 +285,13 @@ class FusedMoE(torch.nn.Module):
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
self
.
tp_size
=
(
tp_size
if
tp_size
is
not
None
else
get_tensor_model_parallel_world_size
())
get_tensor_model_parallel_world_size
())
if
envs
.
VLLM_TEST_ENABLE_EP
:
self
.
ep_size
=
self
.
tp_size
self
.
tp_size
=
1
else
:
self
.
ep_size
=
1
self
.
top_k
=
top_k
self
.
top_k
=
top_k
self
.
num_experts
=
num_experts
self
.
num_experts
=
num_experts
# Global number of experts
assert
intermediate_size
%
self
.
tp_size
==
0
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
...
@@ -281,6 +304,26 @@ class FusedMoE(torch.nn.Module):
...
@@ -281,6 +304,26 @@ class FusedMoE(torch.nn.Module):
self
.
custom_routing_function
=
custom_routing_function
self
.
custom_routing_function
=
custom_routing_function
self
.
scoring_func
=
scoring_func
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
expert_map
=
None
if
self
.
ep_size
>
1
:
# Create a tensor of size num_experts filled with -1
self
.
expert_map
=
torch
.
full
((
self
.
num_experts
,
),
-
1
,
dtype
=
torch
.
int32
)
# Create a expert map for the local experts
local_num_experts
=
num_experts
//
self
.
ep_size
ep_rank
=
get_tensor_model_parallel_rank
()
if
ep_rank
<
(
self
.
ep_size
-
1
):
# Each non-last rank gets local_num_experts experts.
self
.
expert_map
[
ep_rank
*
local_num_experts
:
(
ep_rank
+
1
)
*
local_num_experts
]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
else
:
# All remaining experts are assigned to the last rank.
local_num_experts
=
num_experts
-
ep_rank
*
local_num_experts
self
.
expert_map
[
-
local_num_experts
:]
=
\
torch
.
arange
(
0
,
local_num_experts
,
dtype
=
torch
.
int32
)
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
raise
ValueError
(
"Only softmax scoring function is supported for "
...
@@ -293,8 +336,11 @@ class FusedMoE(torch.nn.Module):
...
@@ -293,8 +336,11 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
local_num_experts
=
torch
.
sum
(
self
.
expert_map
!=
-
1
)
\
if
self
.
expert_map
is
not
None
else
num_experts
moe_quant_params
=
{
moe_quant_params
=
{
"num_experts"
:
num_experts
,
"num_experts"
:
local_
num_experts
,
"hidden_size"
:
hidden_size
,
"hidden_size"
:
hidden_size
,
"intermediate_size_per_partition"
:
"intermediate_size_per_partition"
:
self
.
intermediate_size_per_partition
,
self
.
intermediate_size_per_partition
,
...
@@ -423,10 +469,22 @@ class FusedMoE(torch.nn.Module):
...
@@ -423,10 +469,22 @@ class FusedMoE(torch.nn.Module):
assert
shard_id
in
(
"w1"
,
"w3"
)
assert
shard_id
in
(
"w1"
,
"w3"
)
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
def
_map_global_expert_id_to_local_expert_id
(
self
,
expert_id
:
int
)
->
int
:
if
self
.
expert_map
is
None
:
return
expert_id
return
self
.
expert_map
[
expert_id
].
item
()
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
expert_id
=
self
.
_map_global_expert_id_to_local_expert_id
(
expert_id
)
if
expert_id
==
-
1
:
return
# TP rank is set to 0 if EP is enabled
tp_rank
=
0
if
self
.
ep_size
>
1
else
get_tensor_model_parallel_rank
()
# compressed-tensors checkpoints with packed weights are stored flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
# against known CompressionFormat enum values that have this quality
...
@@ -447,7 +505,6 @@ class FusedMoE(torch.nn.Module):
...
@@ -447,7 +505,6 @@ class FusedMoE(torch.nn.Module):
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed: if the dim to shard the weight
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be flipped. Required by GPTQ, compressed-tensors
...
@@ -590,13 +647,16 @@ class FusedMoE(torch.nn.Module):
...
@@ -590,13 +647,16 @@ class FusedMoE(torch.nn.Module):
top_k
=
self
.
top_k
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
renormalize
=
self
.
renormalize
,
use_grouped_topk
=
self
.
use_grouped_topk
,
use_grouped_topk
=
self
.
use_grouped_topk
,
global_num_experts
=
self
.
num_experts
,
expert_map
=
self
.
expert_map
,
topk_group
=
self
.
topk_group
,
topk_group
=
self
.
topk_group
,
num_expert_group
=
self
.
num_expert_group
,
num_expert_group
=
self
.
num_expert_group
,
custom_routing_function
=
self
.
custom_routing_function
,
custom_routing_function
=
self
.
custom_routing_function
,
scoring_func
=
self
.
scoring_func
,
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
)
e_score_correction_bias
=
self
.
e_score_correction_bias
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
View file @
781096e3
...
@@ -10,7 +10,9 @@ def fused_moe(
...
@@ -10,7 +10,9 @@ def fused_moe(
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
,
renormalize
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -18,6 +20,7 @@ def fused_moe(
...
@@ -18,6 +20,7 @@ def fused_moe(
w1: [num_experts, intermediate_size * 2, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
gating_output: [*, num_experts]
expert_map: [num_experts]
"""
"""
orig_shape
=
hidden_states
.
shape
orig_shape
=
hidden_states
.
shape
hidden_size
=
hidden_states
.
shape
[
-
1
]
hidden_size
=
hidden_states
.
shape
[
-
1
]
...
@@ -27,13 +30,16 @@ def fused_moe(
...
@@ -27,13 +30,16 @@ def fused_moe(
dtype
=
hidden_states
.
dtype
dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
view
(
num_tokens
,
hidden_size
)
hidden_states
=
hidden_states
.
view
(
num_tokens
,
hidden_size
)
gating_output
=
gating_output
.
view
(
num_tokens
,
num_experts
)
gating_output
=
gating_output
.
view
(
num_tokens
,
global_
num_experts
)
topk_weights
=
gating_output
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float
)
topk_weights
=
gating_output
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float
)
topk_weights
,
selected_experts
=
topk_weights
.
topk
(
topk
,
dim
=-
1
)
topk_weights
,
selected_experts
=
topk_weights
.
topk
(
topk
,
dim
=-
1
)
if
renormalize
:
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
.
to
(
dtype
)
topk_weights
=
topk_weights
.
to
(
dtype
)
if
expert_map
is
not
None
:
selected_experts
=
expert_map
[
selected_experts
]
final_hidden_states
=
None
final_hidden_states
=
None
for
expert_idx
in
range
(
num_experts
):
for
expert_idx
in
range
(
num_experts
):
expert_w1
=
w1
[
expert_idx
]
expert_w1
=
w1
[
expert_idx
]
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
781096e3
...
@@ -464,10 +464,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -464,10 +464,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
expert_map
is
not
None
:
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
781096e3
...
@@ -214,6 +214,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -214,6 +214,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -239,6 +241,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -239,6 +241,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
...
@@ -540,10 +544,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -540,10 +544,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
expert_map
is
not
None
:
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
781096e3
...
@@ -108,6 +108,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -108,6 +108,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -133,6 +135,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
...
@@ -133,6 +135,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_int8_w8a16
=
True
,
use_int8_w8a16
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scale
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
w2_scale
=
layer
.
w2_scale
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
781096e3
...
@@ -670,6 +670,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -670,6 +670,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -697,6 +699,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -697,6 +699,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale_inv
w2_scale
=
(
layer
.
w2_weight_scale_inv
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
781096e3
...
@@ -585,6 +585,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -585,6 +585,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
781096e3
...
@@ -288,6 +288,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -288,6 +288,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -317,6 +319,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -317,6 +319,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
inplace
=
True
,
inplace
=
True
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int8_w8a16
=
weight_bits
==
8
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scales
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
w2_scale
=
layer
.
w2_scales
,
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
781096e3
...
@@ -198,6 +198,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -198,6 +198,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
use_grouped_topk
:
bool
=
False
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -223,6 +225,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
...
@@ -223,6 +225,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
781096e3
...
@@ -106,10 +106,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -106,10 +106,6 @@ class DeepseekV2MoE(nn.Module):
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
n_shared_experts
=
config
.
n_shared_experts
self
.
n_shared_experts
=
config
.
n_shared_experts
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
if
self
.
tp_size
>
config
.
n_routed_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
n_routed_experts
}
."
)
if
config
.
hidden_act
!=
"silu"
:
if
config
.
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment