norm / vllm · Commit d0d93b92 (unverified)

Add unit test for Mixtral MoE layer (#2677)

Authored Jan 31, 2024 by Philipp Moritz; committed via GitHub on Jan 31, 2024.
Parent commit: 89efcf1c
Showing 4 changed files with 119 additions and 5 deletions (+119 -5):

Dockerfile                                   +6    -0
tests/kernels/test_moe.py                    +104  -0
vllm/model_executor/layers/fused_moe.py      +3    -1
vllm/model_executor/models/mixtral.py        +6    -4
Dockerfile

@@ -7,6 +7,12 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
 RUN apt-get update -y \
     && apt-get install -y python3-pip git

+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.1/compat/
+
 WORKDIR /workspace

 # install build and runtime dependencies
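The added comment states the intent: make the driver libraries under the CUDA compat directory resolvable so that triton and torch can load them. As a hedged sanity check (not part of this commit), something like the following could be run inside the built image to confirm the driver library now resolves; the exact failure mode tracked in the linked issues is not restated here.

# Hypothetical sanity check, not part of the commit: verify that the CUDA
# driver library resolves inside the image after the ldconfig workaround.
import ctypes

import torch

ctypes.CDLL("libcuda.so.1")        # raises OSError if the loader cannot find it
print(torch.cuda.is_available())   # expected to print True on a GPU host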
tests/kernels/test_fused_moe.py → tests/kernels/test_moe.py

+"""Tests for the MOE layers.
+
+Run `pytest tests/kernels/test_moe.py`.
+"""
 import pytest
 import torch

+from transformers import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.models.mixtral import MixtralMoE


 def torch_moe(a, w1, w2, topk_weight, topk_ids):
@@ -48,3 +57,48 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
     torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
     assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
+
+
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
+@torch.inference_mode()
+def test_mixtral_moe(dtype: torch.dtype):
+    "Make sure our Mixtral MoE implementation agrees with the one from huggingface."
+
+    # Instantiate our and huggingface's MoE blocks
+    config = MixtralConfig()
+    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
+    vllm_moe = MixtralMoE(
+        num_experts=config.num_local_experts,
+        top_k=config.num_experts_per_tok,
+        hidden_size=config.hidden_size,
+        intermediate_size=config.intermediate_size,
+        params_dtype=dtype,
+        tp_size=1,
+    )
+
+    # Load the weights
+    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+    for i in range(config.num_local_experts):
+        weights = (hf_moe.experts[i].w1.weight.data,
+                   hf_moe.experts[i].w3.weight.data)
+        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
+
+    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
+    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+
+    # Run forward passes for both MoE blocks
+    hf_states, _ = hf_moe.forward(inputs)
+    vllm_states = vllm_moe.forward(inputs)
+
+    mixtral_moe_tol = {
+        torch.float32: 1e-3,
+        torch.float16: 1e-3,
+        torch.bfloat16: 1e-2,
+    }
+
+    assert torch.allclose(hf_states,
+                          vllm_states,
+                          rtol=mixtral_moe_tol[dtype],
+                          atol=mixtral_moe_tol[dtype])
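A minimal sketch (illustration only, not repository code) of why the test can concatenate the HuggingFace expert's w1 and w3 matrices along dim 0 when filling vllm_moe.ws[i]: the Mixtral expert MLP computes w2(silu(w1 x) * (w3 x)), and applying a silu-and-multiply over the two halves of a single fused projection reproduces it. The shapes below are arbitrary, and torch.nn.functional.silu stands in for vLLM's SiluAndMul activation.

# Illustration only: equivalence between the HF per-expert MLP and the fused
# cat(w1, w3) layout the test writes into vllm_moe.ws[i].
import torch
import torch.nn.functional as F

hidden, intermediate = 8, 16
x = torch.randn(4, hidden)
w1 = torch.randn(intermediate, hidden)   # gate projection
w3 = torch.randn(intermediate, hidden)   # up projection
w2 = torch.randn(hidden, intermediate)   # down projection

# HuggingFace-style expert: w2( silu(w1 x) * (w3 x) )
ref = (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

# Fused layout: one matmul against cat(w1, w3), then silu-and-multiply over
# the two halves (what SiluAndMul computes).
h = x @ torch.cat((w1, w3), dim=0).T
gate, up = h.chunk(2, dim=-1)
out = (F.silu(gate) * up) @ w2.T

assert torch.allclose(out, ref, atol=1e-5)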
vllm/model_executor/layers/fused_moe.py

@@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
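With the relaxed assertion, the fused kernel can also be exercised in float32, which the new parametrized test relies on. Below is a hedged usage sketch mirroring the call shown in the test above; the shape conventions (a: [M, K], w1: [E, 2N, K], w2: [E, K, N]) and the trailing positional inplace flag are assumptions taken from that test, not a documented API contract.

# Hedged sketch of a float32 call, mirroring tests/kernels/test_moe.py.
# Shapes are assumed from the existing test, not from the function signature.
import torch

from vllm.model_executor.layers.fused_moe import fused_moe

m, n, k, e, topk = 64, 128, 256, 8, 2
a = torch.randn((m, k), device="cuda", dtype=torch.float32)
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float32)
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float32)

score = torch.softmax(torch.randn((m, e), device="cuda"), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)

# Rejected by the old dtype assert; permitted after this change.
out = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
print(out.shape)   # expected torch.Size([64, 256])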
vllm/model_executor/models/mixtral.py

@@ -70,13 +70,14 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -141,8 +142,9 @@ class MixtralMoE(nn.Module):
                                         selected_experts,
                                         inplace=True)

-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)

         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
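The new guard skips the all-reduce when the layer is not sharded. A small sketch (plain torch, no vLLM code) of the reasoning, assuming the usual tensor-parallel MLP scheme that the intermediate_size // self.tp_size line above implies: the intermediate dimension is split across ranks and partial outputs are summed, so with a single rank there is exactly one full partial and no reduction is needed.

# Illustration only: why summing per-rank partials (the all-reduce) recovers
# the unsharded result, and why tp_size == 1 needs no reduction at all.
import torch

torch.manual_seed(0)
hidden, intermediate, tp = 8, 16, 2

x = torch.randn(4, hidden)
w_in = torch.randn(intermediate, hidden)    # column-parallel: rows sharded
w_out = torch.randn(hidden, intermediate)   # row-parallel: columns sharded

# Unsharded reference MLP (ReLU stands in for the real activation).
ref = torch.relu(x @ w_in.T) @ w_out.T

# Each "rank" owns intermediate // tp of the intermediate neurons and produces
# a partial output; the sum of the partials equals the reference.
shard = intermediate // tp
partials = [
    torch.relu(x @ w_in[r * shard:(r + 1) * shard].T)
    @ w_out[:, r * shard:(r + 1) * shard].T
    for r in range(tp)
]
assert torch.allclose(sum(partials), ref, atol=1e-5)
# With tp == 1 the single partial already equals ref, which is what the new
# `if self.tp_size > 1` guard exploits.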