gaoqiong/flash-attention, commit 4d87e4d8

Implement GPT-J

Authored Mar 22, 2023 by Tri Dao
Parent commit: 4360cfc6

Showing 11 changed files with 522 additions and 87 deletions:
flash_attn/models/gpt.py                       +213   -56
flash_attn/models/gptj.py                       +95    -0
flash_attn/models/opt.py                         +1    -1
flash_attn/modules/block.py                     +90    -0
flash_attn/modules/mha.py                       +25   -19
flash_attn/utils/generation.py                  +11    -4
tests/models/test_gpt.py                         +2    -2
tests/models/test_gpt_generation.py              +2    -2
tests/models/test_gpt_generation_parallel.py     +1    -1
tests/models/test_gptj.py                       +80    -0
tests/models/test_opt.py                         +2    -2
flash_attn/models/gpt.py
(Diff collapsed in the original commit view; contents not shown.)
flash_attn/models/gptj.py (new file, mode 100644)
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r'^transformer.h.', 'transformer.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                  * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('lm_head.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f'transformer.layers.{l}.attn.bias')
        state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')

    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
    )
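For orientation, a minimal usage sketch (not part of the commit) of how these helpers are meant to be used, mirroring tests/models/test_gptj.py below. It assumes flash_attn is installed, a CUDA device, and enough memory to hold GPT-J-6B in fp16:

import torch
from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

# Convert the HF GPT-J config into the (extended) GPT2Config that GPTLMHeadModel expects.
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B"))
config.use_flash_attn = False      # FlashAttention doesn't support hdim 256 yet (per the test below)
config.fused_bias_fc = True
config.residual_in_fp32 = True     # Only prenorm supports residual_in_fp32

# from_pretrained loads the HF checkpoint and remaps its keys, as exercised by test_gptj_optimized below.
model = GPTLMHeadModel.from_pretrained("EleutherAI/gpt-j-6B", config,
                                       device='cuda', dtype=torch.float16)
model.eval()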
flash_attn/models/opt.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from transformers import GPT2Config, OPTConfig


-def remap_state_dict_opt(state_dict, config):
+def remap_state_dict_hf_opt(state_dict, config):
     def key_mapping_model(key):
         key = re.sub(r'^model.decoder.', 'transformer.', key)
         # The OPT-350m model uses '^decoder' instead of '^model.decoder'
flash_attn/modules/block.py
@@ -190,3 +190,93 @@ class Block(nn.Module):
                                                  rowscale=rowscale2, prenorm=False)
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
                 sequence_parallel=False, mark_shared_params=False):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
        the hidden_states (output1 of the MHA / MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, 'norm2'):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
                residual: Optional[Tensor] = None, mixer_kwargs=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
        dropped1 = self.dropout1(hidden_states1)
        # For the very 1st block, we only want 1 dropout, not two different dropouts
        if hidden_states2 is not None:
            dropped2 = self.dropout2(hidden_states2)
            residual = ((residual + dropped1 + dropped2)
                        if residual is not None else dropped1 + dropped2)
        else:
            residual = (residual + dropped1) if residual is not None else dropped1
        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                          if not self.tied_norm else hidden_states1)
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual
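For orientation, a minimal sketch (not part of the commit) of how ParallelBlock is chained: each block returns the attention output, the MLP output, and the running residual, and the next block performs the dropout -> add -> LayerNorm before running its own attention and MLP in parallel. The dimensions here are made-up toy values and the default MHA/Mlp sub-modules are assumed:

import torch
from flash_attn.modules.block import ParallelBlock

dim, depth = 256, 2
blocks = torch.nn.ModuleList([ParallelBlock(dim, tied_norm=True) for _ in range(depth)])

x = torch.randn(1, 16, dim)                 # (batch, seqlen, dim) embeddings
hidden1, hidden2, residual = x, None, None  # the very first block gets only one input stream
for block in blocks:
    hidden1, hidden2, residual = block(hidden1, hidden2, residual)
# A full model (see gpt.py) would still apply the final dropout + add + LayerNorm here.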
flash_attn/modules/mha.py
@@ -347,9 +347,10 @@ class MHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=0,
+    def __init__(self, embed_dim, num_heads, cross_attn=False,
+                 qkv_proj_bias=True, out_proj_bias=True, dropout=0.0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
+                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
         """

@@ -377,7 +378,7 @@ class MHA(nn.Module):
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)

         if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')

@@ -388,18 +389,22 @@ class MHA(nn.Module):
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
-                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
+                                       **factory_kwargs)
             else:
-                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
+                                             **factory_kwargs)
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
-            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
             if not self.return_residual:
-                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
+                                      **factory_kwargs)
             else:
-                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
+                                            **factory_kwargs)
             if self.dwconv:
                 self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                           groups=embed_dim)

@@ -409,8 +414,7 @@ class MHA(nn.Module):
                                                      attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
-        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
+        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

@@ -526,9 +530,10 @@ class ParallelMHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+    def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True,
+                 out_proj_bias=True, dropout=0.0, softmax_scale=None, causal=False,
+                 layer_idx=None, rotary_emb_dim=0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -546,11 +551,12 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)

         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
-        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
+        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group,
+                                         bias=qkv_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention

@@ -558,8 +564,8 @@ class ParallelMHA(nn.Module):
                                                      attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, bias=out_proj_bias,
                                           sequence_parallel=sequence_parallel, **factory_kwargs)

     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
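To illustrate the new arguments, a hedged construction sketch (not part of the commit) of an MHA configured the way gptj_config_to_gpt2_config sets things up: no QKV or output-projection bias, and interleaved rotary embeddings over part of each head. It assumes the optional rotary_emb extension is built; the GPT-J-6B sizes are used only as an example:

from flash_attn.modules.mha import MHA

# GPT-J-6B sizes, for illustration: n_embd=4096, n_head=16 -> head_dim=256, rotary_dim=64,
# so rotary_emb_fraction = 64 / 256 = 0.25 of each head gets (interleaved) rotary embeddings.
mha = MHA(embed_dim=4096, num_heads=16, causal=True,
          qkv_proj_bias=False, out_proj_bias=False,       # GPT-J has no biases on these projections
          rotary_emb_dim=64, rotary_emb_interleaved=True,
          use_flash_attn=False)                           # head dim 256 is not supported by FlashAttention yet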
flash_attn/utils/generation.py
@@ -71,8 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):

 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
-           cg=False, timing=False):
+           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
+           fused_ft_kernel=False, cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,

@@ -87,6 +87,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
+    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
         assert fused_ft_kernel
         if not hasattr(model, '_decoding_cache'):

@@ -111,7 +112,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     if vocab_size is not None:
         logits = logits[..., :vocab_size]
     scores.append(logits)
+    if teacher_outputs is None or teacher_output_len <= seqlen_og:
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    else:
+        next_token = teacher_outputs[:, seqlen_og]
     sequences = [next_token]
     inference_params.sequence_len_offset = seqlen_og
     while True:

@@ -126,7 +130,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
+            next_token = sample(logits, top_k=top_k, temperature=temperature)
+        else:
+            next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
         sequences.append(next_token)
         inference_params.sequence_len_offset += 1
         if eos_token_id is not None and (next_token == eos_token_id).all():
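A hedged sketch (not part of the commit) of the new teacher_outputs argument: when a long-enough tensor of token ids is provided, decode() takes each next token from it instead of sampling, which is handy for checking the incremental-decoding path against a known sequence. Here model and input_ids are assumed to already exist, and teacher_token_ids is a hypothetical LongTensor of shape (batch, max_length):

from flash_attn.utils.generation import decode

out = decode(input_ids, model, max_length=64, top_k=1,
             teacher_outputs=teacher_token_ids,   # hypothetical (batch, max_length) token ids
             fused_ft_kernel=False)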
tests/models/test_gpt.py
@@ -7,7 +7,7 @@ from transformers import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_state_dict(model_name):
     config = GPT2Config.from_pretrained(model_name)
-    pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()
tests/models/test_gpt_generation.py
@@ -12,8 +12,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.generation import update_graph_cache
tests/models/test_gpt_generation_parallel.py
@@ -12,7 +12,7 @@ from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
tests/models/test_gptj.py (new file, mode 100644)
import re

import torch
import pytest

from transformers import GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
                            for l in range(config.n_layer)}
    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
    for k in state_dict.keys() - rotary_inv_freq_keys:
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = False  # We don't support parallel block yet
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
tests/models/test_opt.py
@@ -7,7 +7,7 @@ from transformers import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_state_dict(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
-    pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()