Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41deac4a
Unverified
Commit
41deac4a
authored
Mar 24, 2024
by
Nick Hill
Committed by
GitHub
Mar 24, 2024
Browse files
[BugFix] 1D query fix for MoE models (#3597)
parent
af9e5349
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
15 deletions
+15
-15
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+6
-4
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+3
-4
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+3
-4
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+3
-3
No files found.
tests/kernels/test_moe.py
View file @
41deac4a
...
...
@@ -81,11 +81,13 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_moe
.
w2s
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
inputs
)
vllm_states
=
vllm_moe
.
forward
(
inputs
)
hf_states
,
_
=
hf_moe
.
forward
(
hf_
inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_
inputs
)
mixtral_moe_tol
=
{
torch
.
float32
:
1e-3
,
...
...
@@ -93,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
bfloat16
:
1e-2
,
}
assert
torch
.
allclose
(
hf_states
,
assert
torch
.
allclose
(
hf_states
.
flatten
(
0
,
1
)
,
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
vllm/model_executor/models/deepseek.py
View file @
41deac4a
...
...
@@ -150,11 +150,11 @@ class DeepseekMoE(nn.Module):
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
config
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
w1
,
...
...
@@ -169,8 +169,7 @@ class DeepseekMoE(nn.Module):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
batch_size
,
sequence_length
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
class
DeepseekAttention
(
nn
.
Module
):
...
...
vllm/model_executor/models/mixtral.py
View file @
41deac4a
...
...
@@ -124,9 +124,9 @@ class MixtralMoE(nn.Module):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_size
=
hidden_states
.
shape
num_tokens
,
hidden_size
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
...
...
@@ -140,8 +140,7 @@ class MixtralMoE(nn.Module):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
batch_size
,
sequence_length
,
hidden_size
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
class
MixtralAttention
(
nn
.
Module
):
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
41deac4a
...
...
@@ -132,9 +132,9 @@ class MixtralMoE(nn.Module):
linear_method
=
None
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
,
sequence_length
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
# router_logits: (
batch * sequence_length
, n_experts)
# router_logits: (
num_tokens
, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
...
...
@@ -158,7 +158,7 @@ class MixtralMoE(nn.Module):
final_hidden_states
.
add_
(
current_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
).
view
(
batch_size
,
sequence_length
,
hidden_dim
)
num_tokens
,
hidden_dim
)
class
MixtralAttention
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment