Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54951ac4
Unverified
Commit
54951ac4
authored
Apr 06, 2024
by
Isotr0py
Committed by
GitHub
Apr 05, 2024
Browse files
[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869)
parent
18de8834
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
19 deletions
+12
-19
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+12
-19
No files found.
vllm/model_executor/models/olmo.py
View file @
54951ac4
...
...
@@ -39,14 +39,15 @@
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
# this model must need this dependency
from
hf_olmo
import
OLMoConfig
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from
vllm.sequence
import
SamplerOutput
class
SwiGLU
(
nn
.
Module
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
gate
=
x
.
chunk
(
2
,
dim
=-
1
)
return
F
.
silu
(
gate
)
*
x
@
property
def
output_multiplier
(
self
)
->
float
:
return
0.5
class
OlmoAttention
(
nn
.
Module
):
"""
This is the attention block where the output is computed as
...
...
@@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
bias
=
False
)
# Feed-forward input projection.
self
.
ff_proj
=
ColumnParallelLinear
(
self
.
ff_proj
=
Merged
ColumnParallelLinear
(
config
.
d_model
,
self
.
hidden_size
,
[
self
.
hidden_size
//
2
]
*
2
,
bias
=
config
.
include_bias
,
linear_method
=
linear_method
,
)
# Activation function.
# self.act = SiluAndMul()
# self.act.output_multiplier = 0.5
self
.
act
=
SwiGLU
()
self
.
act
=
SiluAndMul
()
self
.
act
.
output_multiplier
=
0.5
assert
(
self
.
act
.
output_multiplier
*
self
.
hidden_size
)
%
1
==
0
# Feed-forward output projection.
...
...
@@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
if
".att"
in
name
:
name
=
name
.
replace
(
".att"
,
".attn.att"
)
# mlp
if
".ff"
in
name
and
"transformer.ff_out"
not
in
name
:
name
=
name
.
replace
(
".ff"
,
".mlp.ff"
)
if
".ff_proj"
in
name
:
name
=
name
.
replace
(
".ff_proj"
,
".mlp.ff_proj"
)
# Reverse the weight for the MergeColumnParallelLinear
loaded_weight
=
torch
.
concat
(
loaded_weight
.
chunk
(
2
)[::
-
1
])
if
".ff_out"
in
name
and
"transformer.ff_out"
not
in
name
:
name
=
name
.
replace
(
".ff_out"
,
".mlp.ff_out"
)
# there is no bias in olmo
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment