gaoqiong / flash-attention · Commits

Commit 73df3be7
authored Dec 25, 2023 by Tri Dao

    Add test for BTLM init

parent 7ffba9a5

Showing 2 changed files with 47 additions and 14 deletions:

    flash_attn/models/gpt.py   (+3, -1)
    tests/models/test_btlm.py  (+44, -13)
flash_attn/models/gpt.py

...
@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
...
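Note on the gpt.py change: the direct assignment to module.weight._optim is replaced by a read-merge-write, so setting lr_multiplier here no longer clobbers any _optim keys attached to the weight elsewhere. As a rough illustration of how such metadata could be consumed downstream (not part of this commit; build_param_groups and base_lr are hypothetical names), training code might group parameters by lr_multiplier when constructing the optimizer:

import torch

# Hypothetical consumer of per-parameter `_optim` metadata: bucket
# parameters by lr_multiplier and scale the base learning rate per bucket.
def build_param_groups(model, base_lr=1e-3):
    groups = {}
    for param in model.parameters():
        mult = getattr(param, "_optim", {}).get("lr_multiplier", 1.0)
        groups.setdefault(mult, []).append(param)
    return [{"params": ps, "lr": base_lr * m} for m, ps in groups.items()]

# optimizer = torch.optim.AdamW(build_param_groups(model))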
tests/models/test_btlm.py

# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path

import torch
import pytest
from einops import rearrange
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
...
...
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(AutoConfig.from_pretrained(model_name, trust_remote_code=True))
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
...
...
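A side note on the device="meta" trick used above: meta tensors record shape and dtype but allocate no storage, so a multi-billion-parameter model can be constructed almost instantly when only its state-dict structure is needed. A minimal sketch (the layer size is an arbitrary example):

import torch.nn as nn

# Meta-device parameters carry shape/dtype but no backing storage.
layer = nn.Linear(8192, 8192, device="meta")
print(layer.weight.shape)    # torch.Size([8192, 8192])
print(layer.weight.is_meta)  # True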
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
...
...
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
...
...
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
+
+
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
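The new test_btlm_init compares flash-attention's BTLM initialization against the Hugging Face reference statistically: weight means near zero, per-tensor standard deviations matching across the two implementations, and biases exactly zero. This works because nn.init.normal_ on a large tensor yields an empirical std very close to the requested one; a toy demonstration of that property, using example values for initializer_range and mup_width_scale (the real ones come from the BTLM config):

import math
import torch
import torch.nn as nn

# Example values, not taken from the BTLM config.
initializer_range, mup_width_scale = 0.02, 0.25
w = torch.empty(4096, 4096)
nn.init.normal_(w, std=initializer_range * math.sqrt(mup_width_scale))
print(w.mean().abs().item() < 1e-4)  # True with overwhelming probability
print(w.std().item())                # ~0.01 = 0.02 * sqrt(0.25)

The test itself can be run with pytest tests/models/test_btlm.py -k test_btlm_init; it requires a CUDA device and access to the cerebras/btlm-3b-8k-base config on the Hugging Face Hub.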