Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
7ffba9a5
Commit
7ffba9a5
authored
Dec 24, 2023
by
Tri Dao
Browse files
Implement BTLM model
parent
2e29dacf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
316 additions
and
0 deletions
+316
-0
flash_attn/models/btlm.py
flash_attn/models/btlm.py
+102
-0
tests/models/test_btlm.py
tests/models/test_btlm.py
+214
-0
No files found.
flash_attn/models/btlm.py
0 → 100644
View file @
7ffba9a5
# Copyright (c) 2023, Tri Dao.
import
math
import
json
import
re
from
pathlib
import
Path
from
collections
import
OrderedDict
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
GPT2Config
,
AutoConfig
,
PretrainedConfig
def
remap_state_dict_hf_btlm
(
state_dict
,
config
):
# Word embedding and position embedding
def
key_mapping_pos_emb
(
key
):
return
re
.
sub
(
r
"^transformer.wpe."
,
"transformer.embeddings.position_embeddings."
,
key
)
if
"transformer.wpe.weight"
in
state_dict
:
state_dict
=
OrderedDict
((
key_mapping_pos_emb
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
word_embeddings
=
state_dict
.
pop
(
"transformer.wte.weight"
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple
=
getattr
(
config
,
"pad_vocab_size_multiple"
,
1
)
vocab_size
=
math
.
ceil
(
config
.
vocab_size
/
pad_vocab_size_multiple
)
*
pad_vocab_size_multiple
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
=
F
.
pad
(
word_embeddings
,
(
0
,
0
,
0
,
vocab_size
-
word_embeddings
.
shape
[
0
])
)
state_dict
[
"lm_head.weight"
]
=
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
# LayerNorm
def
key_mapping_ln
(
key
):
key
=
re
.
sub
(
r
"^transformer.ln_f.(weight|bias)"
,
r
"transformer.ln_f.\1"
,
key
)
key
=
re
.
sub
(
r
"^transformer.h.(\d+).ln_(1|2).(weight|bias)"
,
r
"transformer.layers.\1.norm\2.\3"
,
key
)
return
key
state_dict
=
OrderedDict
((
key_mapping_ln
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# MLP
for
d
in
range
(
config
.
num_hidden_layers
):
W1
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.mlp.c_fc.weight"
)
W3
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.mlp.c_fc2.weight"
)
state_dict
[
f
"transformer.layers.
{
d
}
.mlp.fc1.weight"
]
=
torch
.
cat
([
W1
.
t
(),
W3
.
t
()],
dim
=
0
)
b1
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.mlp.c_fc.bias"
)
b3
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.mlp.c_fc2.bias"
)
state_dict
[
f
"transformer.layers.
{
d
}
.mlp.fc1.bias"
]
=
torch
.
cat
([
b1
,
b3
],
dim
=
0
)
W2
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.mlp.c_proj.weight"
)
state_dict
[
f
"transformer.layers.
{
d
}
.mlp.fc2.weight"
]
=
W2
.
t
()
def
key_mapping_mlp
(
key
):
key
=
re
.
sub
(
r
"^transformer.h.(\d+).mlp.c_proj.bias"
,
r
"transformer.layers.\1.mlp.fc2.bias"
,
key
)
return
key
state_dict
=
OrderedDict
((
key_mapping_mlp
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Attention
for
d
in
range
(
config
.
num_hidden_layers
):
Wqkv
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.attn.c_attn.weight"
)
state_dict
[
f
"transformer.layers.
{
d
}
.mixer.Wqkv.weight"
]
=
Wqkv
.
t
()
Wout
=
state_dict
.
pop
(
f
"transformer.h.
{
d
}
.attn.c_proj.weight"
)
state_dict
[
f
"transformer.layers.
{
d
}
.mixer.out_proj.weight"
]
=
Wout
.
t
()
state_dict
.
pop
(
f
"transformer.relative_pe.slopes"
)
# We don't store the Alibi slopes
def
key_mapping_attn
(
key
):
key
=
re
.
sub
(
r
"^transformer.h.(\d+).attn.c_attn.bias"
,
r
"transformer.layers.\1.mixer.Wqkv.bias"
,
key
)
key
=
re
.
sub
(
r
"^transformer.h.(\d+).attn.c_proj.bias"
,
r
"transformer.layers.\1.mixer.out_proj.bias"
,
key
)
return
key
state_dict
=
OrderedDict
((
key_mapping_attn
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
return
state_dict
def
btlm_config_to_gpt2_config
(
btlm_config
:
PretrainedConfig
)
->
GPT2Config
:
return
GPT2Config
(
vocab_size
=
btlm_config
.
vocab_size
,
n_positions
=
0
if
btlm_config
.
position_embedding_type
==
"alibi"
else
btlm_config
.
n_positions
,
n_embd
=
btlm_config
.
hidden_size
,
n_layer
=
btlm_config
.
num_hidden_layers
,
n_head
=
btlm_config
.
num_attention_heads
,
n_inner
=
btlm_config
.
n_inner
,
activation_function
=
btlm_config
.
activation_function
,
resid_pdrop
=
btlm_config
.
resid_pdrop
,
embd_pdrop
=
btlm_config
.
embd_pdrop
,
attn_pdrop
=
btlm_config
.
attn_pdrop
,
layer_norm_epsilon
=
btlm_config
.
layer_norm_epsilon
,
initializer_range
=
btlm_config
.
initializer_range
,
bos_token_id
=
btlm_config
.
bos_token_id
,
eos_token_id
=
btlm_config
.
eos_token_id
,
# These are new arguments not in the original GPT2Config
use_alibi
=
btlm_config
.
position_embedding_type
==
"alibi"
,
use_flash_attn
=
btlm_config
.
position_embedding_type
==
"alibi"
,
# Alibi code path requires flash_attn
mup_width_scale
=
btlm_config
.
mup_width_scale
,
mup_embeddings_multiplier
=
btlm_config
.
mup_embeddings_scale
,
mup_output_multiplier
=
btlm_config
.
mup_output_alpha
,
mup_scale_qk_dot_by_d
=
btlm_config
.
mup_scale_qk_dot_by_d
,
mlp_multiple_of
=
1
,
)
tests/models/test_btlm.py
0 → 100644
View file @
7ffba9a5
# Copyright (c) 2023, Tri Dao.
import
os
import
time
from
pathlib
import
Path
import
torch
import
pytest
from
einops
import
rearrange
from
transformers
import
AutoConfig
,
AutoTokenizer
,
AutoModelForCausalLM
from
flash_attn.models.gpt
import
GPTLMHeadModel
from
flash_attn.models.btlm
import
btlm_config_to_gpt2_config
,
remap_state_dict_hf_btlm
from
flash_attn.utils.pretrained
import
state_dict_from_pretrained
from
flash_attn.utils.generation
import
update_graph_cache
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"cerebras/btlm-3b-8k-base"
])
def
test_btlm_state_dict
(
model_name
):
config
=
btlm_config_to_gpt2_config
(
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
)
pretrained_state_dict
=
remap_state_dict_hf_btlm
(
state_dict_from_pretrained
(
model_name
),
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
"meta"
)
# Without device='meta' init is very slow
state_dict
=
model
.
state_dict
()
assert
len
(
state_dict
.
keys
())
==
len
(
pretrained_state_dict
.
keys
())
assert
state_dict
.
keys
()
==
pretrained_state_dict
.
keys
()
for
k
in
state_dict
.
keys
():
assert
state_dict
[
k
].
shape
==
pretrained_state_dict
[
k
].
shape
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"cerebras/btlm-3b-8k-base"
])
def
test_btlm_optimized
(
model_name
):
"""Check that our implementation of Btlm (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype
=
torch
.
float16
device
=
"cuda"
config
=
btlm_config_to_gpt2_config
(
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
)
config
.
fused_bias_fc
=
True
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
pretrained_state_dict
=
remap_state_dict_hf_btlm
(
state_dict_from_pretrained
(
model_name
),
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
eval
()
torch
.
manual_seed
(
0
)
batch_size
=
2
max_seqlen
=
256
seqlens
=
torch
.
randint
(
max_seqlen
//
2
,
max_seqlen
+
1
,
(
batch_size
,),
device
=
device
)
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
max_seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
with
torch
.
no_grad
():
out
=
model
.
transformer
(
input_ids
)
logits
=
model
(
input_ids
).
logits
del
model
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
)
model_ref
.
eval
()
with
torch
.
no_grad
():
out_ref
=
model_ref
.
transformer
(
input_ids
).
last_hidden_state
.
to
(
device
=
device
)
logits_ref
=
model_ref
(
input_ids
).
logits
.
to
(
device
=
device
)
del
model_ref
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
dtype
,
device_map
=
{
""
:
device
},
trust_remote_code
=
True
,
)
model_hf
.
eval
()
with
torch
.
no_grad
():
out_hf
=
model_hf
.
transformer
(
input_ids
).
last_hidden_state
logits_hf
=
model_hf
(
input_ids
).
logits
del
model_hf
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
out_hf
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<
3
*
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits mean diff:
{
(
logits
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
logits_hf
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
3
*
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"cerebras/btlm-3b-8k-base"
])
def
test_btlm_generation
(
model_name
):
dtype
=
torch
.
float16
device
=
"cuda"
config
=
btlm_config_to_gpt2_config
(
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
)
config
.
fused_bias_fc
=
True
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
eos_token_id
=
tokenizer
.
eos_token_id
torch
.
manual_seed
(
0
)
batch_size
=
1
seqlen
=
2048
max_length
=
2048
+
150
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
dtype
,
device_map
=
{
""
:
device
},
trust_remote_code
=
True
)
model_hf
.
eval
()
print
(
"HF fp16"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out_hf
=
model_hf
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
del
model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
)
model_ref
.
eval
()
with
torch
.
no_grad
():
logits_ref
=
model_ref
(
out_hf
.
sequences
).
logits
[:,
(
seqlen
-
1
)
:
-
1
].
to
(
device
=
device
)
del
model_ref
pretrained_state_dict
=
remap_state_dict_hf_btlm
(
state_dict_from_pretrained
(
model_name
),
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
eval
()
model
(
input_ids
)
# Warm up
print
(
"Without CUDA graph"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
enable_timing
=
True
,
teacher_outputs
=
out_hf
.
sequences
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
# Capture graph outside the timing loop
batch_size
,
seqlen_og
=
input_ids
.
shape
model
.
_decoding_cache
=
update_graph_cache
(
model
,
None
,
batch_size
,
seqlen_og
,
max_length
)
print
(
"With CUDA graph"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out_cg
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
cg
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
enable_timing
=
True
,
teacher_outputs
=
out_hf
.
sequences
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
with
torch
.
no_grad
():
logits_parallel
=
model
(
out_hf
.
sequences
).
logits
[:,
(
seqlen
-
1
)
:
-
1
]
logits_hf
=
torch
.
stack
(
out_hf
.
scores
,
dim
=
1
)
logits
=
torch
.
stack
(
out
.
scores
,
dim
=
1
)
logits_cg
=
torch
.
stack
(
out_cg
.
scores
,
dim
=
1
)
del
model
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
print
(
f
"HF fp16 logits max diff:
{
hf_error
}
"
)
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
(
logits_parallel
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
torch
.
equal
(
logits_cg
,
logits
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment