gaoqiong / flash-attention · Commits

Commit 73df3be7, authored Dec 25, 2023 by Tri Dao

    Add test for BTLM init

parent 7ffba9a5
Showing 2 changed files with 47 additions and 14 deletions:

    flash_attn/models/gpt.py     +3  -1
    tests/models/test_btlm.py    +44 -13
flash_attn/models/gpt.py

@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
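Note on the gpt.py change: switching from a direct assignment to getattr/update/setattr merges the muP learning-rate multiplier into any `_optim` metadata already attached to the parameter, instead of clobbering it. Downstream, an optimizer setup can read this attribute when building parameter groups. The helper below is a hypothetical sketch of that pattern, not an API from this repository:

    import torch

    def build_param_groups(model: torch.nn.Module, base_lr: float):
        """Hypothetical helper: group parameters by their muP lr multiplier."""
        groups = {}
        for param in model.parameters():
            if not param.requires_grad:
                continue
            # _init_weights attaches {"lr_multiplier": mup_width_scale} here;
            # default to no scaling for parameters without metadata.
            cfg = getattr(param, "_optim", {})
            groups.setdefault(cfg.get("lr_multiplier", 1.0), []).append(param)
        return [{"params": ps, "lr": base_lr * mult} for mult, ps in groups.items()]

    # Usage sketch: optimizer = torch.optim.AdamW(build_param_groups(model, 3e-4))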
tests/models/test_btlm.py

 # Copyright (c) 2023, Tri Dao.
-import os
 import time
-from pathlib import Path

 import torch
 import pytest
-from einops import rearrange
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel

@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
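The `device="meta"` comment above is worth unpacking: on PyTorch's meta device only shapes and dtypes are recorded, so constructing a multi-billion-parameter model is near-instant and allocates no memory. In this test the meta model is only used to compare state-dict keys, so its weights are never materialized. A minimal sketch of the general pattern (plain PyTorch, not flash-attn specific):

    import torch
    import torch.nn as nn

    # Construct on the meta device: no storage is allocated, init is instant.
    with torch.device("meta"):
        layer = nn.Linear(8192, 8192)
    assert layer.weight.is_meta

    # Materialize later: allocate uninitialized storage, then load real weights.
    layer = layer.to_empty(device="cpu")
    layer.load_state_dict({"weight": torch.zeros(8192, 8192), "bias": torch.zeros(8192)})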
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref

-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
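Aside on the slice `[:, (seqlen - 1) : -1]` in the reference-logits line: position i's logits predict token i + 1, so the rows that scored the generated tokens start at the last prompt index and stop before the final position (whose logits predict a token that was never sampled). A toy shape check with hypothetical sizes:

    import torch

    seqlen, total_len, vocab = 4, 7, 50257        # 4 prompt tokens, 3 generated
    logits = torch.randn(1, total_len, vocab)     # one logit row per position
    scored = logits[:, (seqlen - 1) : -1]         # rows that scored the generated tokens
    assert scored.shape[1] == total_len - seqlen  # exactly 3 rows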
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)


+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
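The new test builds the flash-attn model and the Hugging Face reference from the same config (no pretrained weights) and checks that their random initializations agree statistically: weight matrices should be near zero-mean with standard deviations matching the reference to within 1e-4, and biases should be exactly zero. A compact restatement of the per-tensor checks (illustrative helper names, not from the test):

    import torch

    def assert_matching_init(w: torch.Tensor, w_ref: torch.Tensor, atol: float = 1e-4):
        # Normal init => sample mean near zero; std must match the reference init.
        assert w.mean().abs() < atol
        assert (w.std() - w_ref.std()).abs() < atol

    def assert_zero_init(b: torch.Tensor):
        assert b.abs().max() == 0.0  # biases are zero-initialized exactly

The test requires a CUDA device and fetches the BTLM config from the Hub with trust_remote_code=True; it can be run with, e.g., `pytest tests/models/test_btlm.py -k test_btlm_init`.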