Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e72ae80b
"vscode:/vscode.git/clone" did not exist on "12a223ef9bfebcc61e477047dce049495fe8c8a8"
Unverified
Commit
e72ae80b
authored
Jul 10, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 10, 2024
Browse files
[Bugfix] Support 2D input shape in MoE layer (#6287)
parent
8a924d22
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
4 deletions
+7
-4
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+3
-2
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+4
-2
No files found.
vllm/model_executor/models/mixtral.py
View file @
e72ae80b
...
@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
...
@@ -88,12 +88,13 @@ class MixtralMoE(nn.Module):
tp_size
=
tp_size
)
tp_size
=
tp_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
router_logits
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_siz
e
)
return
final_hidden_states
.
view
(
orig_shap
e
)
class
MixtralAttention
(
nn
.
Module
):
class
MixtralAttention
(
nn
.
Module
):
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
e72ae80b
...
@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -126,7 +126,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
bias
=
False
)
bias
=
False
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
shared_output
=
None
if
self
.
shared_expert
is
not
None
:
if
self
.
shared_expert
is
not
None
:
...
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
...
@@ -145,7 +147,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
orig_shape
)
class
Qwen2MoeAttention
(
nn
.
Module
):
class
Qwen2MoeAttention
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment