Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c68b5c63
Unverified
Commit
c68b5c63
authored
May 29, 2025
by
rongfu.leng
Committed by
GitHub
May 28, 2025
Browse files
[Misc] fix olmoe model layer can't load in tp gt 1 (#18828)
Signed-off-by:
rongfu.leng
<
rongfu.leng@daocloud.io
>
parent
fced7569
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
4 deletions
+25
-4
vllm/model_executor/models/olmoe.py
vllm/model_executor/models/olmoe.py
+25
-4
No files found.
vllm/model_executor/models/olmoe.py
View file @
c68b5c63
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
functools
import
partial
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
torch
import
torch
...
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
...
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.distributed.utils
import
split_tensor_along_last_dim
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
...
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
q_norm
=
RMSNorm
(
hidden_size
,
eps
=
1e-5
)
self
.
tp_size
=
tp_size
self
.
k_norm
=
RMSNorm
(
hidden_size
,
eps
=
1e-5
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
q_norm
=
RMSNorm
(
self
.
total_num_heads
*
self
.
head_dim
,
eps
=
1e-5
)
self
.
k_norm
=
RMSNorm
(
self
.
total_num_kv_heads
*
self
.
head_dim
,
eps
=
1e-5
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
...
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
...
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
)
def _apply_qk_norm(self, q: torch.Tensor,
                   k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``q_norm`` / ``k_norm`` over the query and key projections.

    When ``self.tp_size`` is 1 the norms are applied to ``q`` and ``k``
    directly. Otherwise each tensor is first passed through
    ``tensor_model_parallel_all_gather``, normalised, and then cut back
    into ``self.tp_size`` pieces along the last dimension, of which only
    the piece at index ``self.tp_rank`` is kept.

    Args:
        q: Query tensor produced by the QKV projection split.
        k: Key tensor produced by the QKV projection split.

    Returns:
        A ``(q, k)`` tuple holding the normalised tensors.
    """
    # Single-rank case: nothing to gather, normalise in place of the shards.
    if self.tp_size <= 1:
        return self.q_norm(q), self.k_norm(k)

    # The all-gather is fed contiguous inputs, exactly as the shards arrive
    # from the split of the fused QKV output.
    # NOTE(review): presumably the gather concatenates along the last
    # dimension so the norm sees every head -- confirm against the helper.
    gathered_q = tensor_model_parallel_all_gather(q.contiguous())
    gathered_k = tensor_model_parallel_all_gather(k.contiguous())

    normed_q = self.q_norm(gathered_q)
    normed_k = self.k_norm(gathered_k)

    # Keep only the slice that belongs to this tensor-parallel rank.
    q_parts = split_tensor_along_last_dim(normed_q,
                                          num_partitions=self.tp_size)
    k_parts = split_tensor_along_last_dim(normed_k,
                                          num_partitions=self.tp_size)
    return q_parts[self.tp_rank], k_parts[self.tp_rank]
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
...
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
q_norm
(
q
.
contiguous
()),
self
.
k_norm
(
k
.
contiguous
()
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment