Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d1f6d1c8
Unverified
Commit
d1f6d1c8
authored
Dec 10, 2024
by
Isotr0py
Committed by
GitHub
Dec 10, 2024
Browse files
[Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
6d525288
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
12 deletions
+34
-12
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+9
-2
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+18
-8
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+7
-2
No files found.
vllm/model_executor/layers/layernorm.py
View file @
d1f6d1c8
...
@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
...
@@ -20,6 +20,7 @@ class RMSNorm(CustomOp):
hidden_size
:
int
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
eps
:
float
=
1e-6
,
var_hidden_size
:
Optional
[
int
]
=
None
,
var_hidden_size
:
Optional
[
int
]
=
None
,
has_weight
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
...
@@ -27,7 +28,11 @@ class RMSNorm(CustomOp):
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
else
var_hidden_size
)
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
self
.
has_weight
=
has_weight
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
def
forward_native
(
def
forward_native
(
self
,
self
,
...
@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
...
@@ -59,7 +64,9 @@ class RMSNorm(CustomOp):
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
variance
=
x_var
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
.
to
(
orig_dtype
)
*
self
.
weight
x
=
x
.
to
(
orig_dtype
)
if
self
.
has_weight
:
x
=
x
*
self
.
weight
if
residual
is
None
:
if
residual
is
None
:
return
x
return
x
else
:
else
:
...
...
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
d1f6d1c8
...
@@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
...
@@ -40,6 +40,7 @@ class MambaMixer(CustomOp):
use_conv_bias
:
bool
,
use_conv_bias
:
bool
,
use_bias
:
bool
,
use_bias
:
bool
,
use_rms_norm
:
bool
,
use_rms_norm
:
bool
,
rms_norm_has_weight
:
bool
=
True
,
rms_norm_eps
:
float
=
1e-5
,
rms_norm_eps
:
float
=
1e-5
,
activation
=
"silu"
):
activation
=
"silu"
):
super
().
__init__
()
super
().
__init__
()
...
@@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
...
@@ -105,14 +106,23 @@ class MambaMixer(CustomOp):
input_is_parallel
=
True
,
input_is_parallel
=
True
,
)
)
self
.
dt_layernorm
=
RMSNorm
(
time_step_rank
,
self
.
dt_layernorm
=
RMSNorm
(
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
time_step_rank
,
eps
=
rms_norm_eps
,
has_weight
=
rms_norm_has_weight
,
)
if
use_rms_norm
else
None
self
.
b_layernorm
=
RMSNorm
(
ssm_state_size
,
self
.
b_layernorm
=
RMSNorm
(
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
ssm_state_size
,
eps
=
rms_norm_eps
,
has_weight
=
rms_norm_has_weight
,
)
if
use_rms_norm
else
None
self
.
c_layernorm
=
RMSNorm
(
ssm_state_size
,
self
.
c_layernorm
=
RMSNorm
(
eps
=
rms_norm_eps
)
if
use_rms_norm
else
None
ssm_state_size
,
eps
=
rms_norm_eps
,
has_weight
=
rms_norm_has_weight
,
)
if
use_rms_norm
else
None
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_native
(
self
,
hidden_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
...
...
vllm/model_executor/models/mamba.py
View file @
d1f6d1c8
"""PyTorch MAMBA model."""
"""PyTorch MAMBA model."""
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
...
@@ -47,6 +47,7 @@ class MambaDecoderLayer(nn.Module):
use_conv_bias
=
config
.
use_conv_bias
,
use_conv_bias
=
config
.
use_conv_bias
,
use_bias
=
config
.
use_bias
,
use_bias
=
config
.
use_bias
,
use_rms_norm
=
self
.
is_falcon_mamba
,
use_rms_norm
=
self
.
is_falcon_mamba
,
rms_norm_has_weight
=
not
self
.
is_falcon_mamba
,
rms_norm_eps
=
mixer_rms_eps
,
rms_norm_eps
=
mixer_rms_eps
,
activation
=
config
.
hidden_act
)
activation
=
config
.
hidden_act
)
...
@@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -241,8 +242,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"A_log"
in
name
:
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
name
=
name
.
replace
(
"A_log"
,
"A"
)
...
@@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
...
@@ -254,3 +257,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment