gaoqiong / flash-attention · Commits

Commit 2e29dacf
Authored Dec 24, 2023 by Tri Dao

Implement muParam

Parent: 3f7d5786

Showing 1 changed file (flash_attn/models/gpt.py) with 24 additions and 4 deletions.
@@ -77,7 +77,9 @@ logger = logging.getLogger(__name__)
 def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {"device": device, "dtype": dtype}
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
+    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
+    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
+    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
     if config.scale_attn_by_inverse_layer_idx:
         assert layer_idx is not None
         softmax_scale /= float(layer_idx + 1)
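This hunk changes the attention softmax scale. A minimal sketch of the resulting computation, using a hypothetical config object (the attribute names come from the diff; the concrete values are illustrative, not from the commit):

import math

class Cfg:  # hypothetical stand-in for a GPT2Config carrying the new muP attributes
    scale_attn_weights = True
    mup_scale_qk_dot_by_d = True   # muP variant: scale q.k by 1/d instead of 1/sqrt(d)
    mup_attn_multiplier = 1.0

config = Cfg()
head_dim = 64
# Mirrors the new logic in create_mixer_cls:
attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-attn_scale_power)
softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
print(softmax_scale)  # 1/64 = 0.015625, versus 1/8 under the standard 1/sqrt(d) parametrization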
@@ -179,12 +181,14 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             if process_group is not None
             else {}
         )
+        mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
         mlp_cls = partial(
             mlp_cls,
             hidden_features=config.n_inner,
             activation=activation,
             bias1=mlp_fc1_bias,
             bias2=mlp_fc2_bias,
+            multiple_of=mlp_multiple_of,
             **parallel_kwargs,
             **factory_kwargs,
         )
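The mlp_multiple_of knob matters under muP because width sweeps produce odd hidden sizes; rounding keeps the intermediate dimension aligned for the fused kernels. A sketch of the usual round-up, assuming the MLP class rounds hidden_features up to the nearest multiple (as flash_attn's gated MLP does):

import math

def round_hidden(hidden_features: int, multiple_of: int = 128) -> int:
    # Round the MLP hidden width up to the next multiple (assumed rounding behavior).
    return multiple_of * math.ceil(hidden_features / multiple_of)

print(round_hidden(2730))  # 2816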
@@ -386,9 +390,13 @@ class GPTPreTrainedModel(nn.Module):
 # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
-def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
+def _init_weights(
+    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
+):
+    mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
-        nn.init.normal_(module.weight, std=initializer_range)
+        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
+        module.weight._optim = {"lr_multiplier": mup_width_scale}
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
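_init_weights now tags every nn.Linear weight with an _optim dict carrying a muP learning-rate multiplier. The commit only attaches the tag; a training loop has to read it when building optimizer parameter groups. A hedged sketch of what such a consumer might look like (the grouping helper below is an assumption, not part of this commit):

import torch

def build_param_groups(model: torch.nn.Module, base_lr: float):
    # Group parameters by the lr_multiplier stored in p._optim, defaulting to 1.0.
    groups = {}
    for p in model.parameters():
        mult = getattr(p, "_optim", {}).get("lr_multiplier", 1.0)
        groups.setdefault(mult, []).append(p)
    return [{"params": ps, "lr": base_lr * mult} for mult, ps in groups.items()]

model = torch.nn.Linear(8, 8)
model.weight._optim = {"lr_multiplier": 0.25}  # as _init_weights would set for width scale 0.25
opt = torch.optim.AdamW(build_param_groups(model, base_lr=1e-3))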
@@ -404,7 +412,9 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
         for name, p in module.named_parameters():
             if name in ["out_proj.weight", "fc2.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
+                nn.init.normal_(
+                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
+                )


 class GPTModel(GPTPreTrainedModel):
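The residual-projection init keeps GPT-2's depth rescaling and folds in the same muP width factor. A quick worked value, assuming mup_width_scale=0.25 and a 12-layer model (illustrative numbers):

import math
# initializer_range * mup_init_scale / sqrt(2 * n_layer)
std = 0.02 * math.sqrt(0.25) / math.sqrt(2 * 12)
print(round(std, 5))  # ~0.00204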
@@ -429,6 +439,7 @@ class GPTModel(GPTPreTrainedModel):
             vocab_size = (
                 math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
             )
+        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
         # These 2 options are for OPT-350m
@@ -494,6 +505,7 @@ class GPTModel(GPTPreTrainedModel):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
             )
         )
         self.tie_weights()
@@ -518,6 +530,8 @@ class GPTModel(GPTPreTrainedModel):
             else {}
         )
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
+        if self.embeddings_multiplier != 1.0:
+            hidden_states = hidden_states * self.embeddings_multiplier
         if self.parallel_block:
             hidden_states2 = None
         residual = None
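Taken together, the new attributes let a single GPT2Config describe a muP-parametrized model. A hypothetical configuration for a model widened 4x from a tuned base width (the numbers are illustrative; only the attribute names come from this commit):

from transformers import GPT2Config

base_width, width = 256, 1024
config = GPT2Config(n_embd=width, n_head=16, n_layer=12)
# muP knobs added by this commit (each defaults to "off" / 1.0 when absent):
config.mup_width_scale = base_width / width   # 0.25: shrinks init std, sets hidden-weight lr multiplier
config.mup_scale_qk_dot_by_d = True           # attention scores scaled by 1/d instead of 1/sqrt(d)
config.mup_attn_multiplier = 1.0
config.mup_embeddings_multiplier = 10.0       # hidden_states *= this right after the embedding layer
config.mup_output_multiplier = 1.0            # combined with mup_width_scale into output_scale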
@@ -612,6 +626,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
         else:
             self.project_out = None
+        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
+        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
+        self.output_scale = mup_output_multiplier * mup_width_scale
         if process_group is None:
             self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
         else:
@@ -632,6 +649,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=mup_width_scale,
             )
         )
         self.tie_weights()
@@ -667,6 +685,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
+        if self.output_scale != 1.0:
+            hidden_states = hidden_states * self.output_scale
         if not self.norm_head:
             lm_logits = self.lm_head(hidden_states)
         else:
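This final hunk applies output_scale to the pre-logit hidden states, completing the recipe: with output_scale = mup_output_multiplier * mup_width_scale, the logits shrink in proportion to base_width/width, which is the standard muP output scaling. A one-line worked example under the hypothetical config sketched above:

# Assumed values: mup_output_multiplier = 1.0, mup_width_scale = 256 / 1024
output_scale = 1.0 * 0.25   # self.output_scale set in GPTLMHeadModel.__init__
# forward pass: hidden_states = hidden_states * 0.25 before lm_head,
# so the logit scale stays stable as the model is widened.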