gaoqiong / flash-attention · Commits · 993d1244

Implement GPT-NeoX

Commit 993d1244, authored Mar 29, 2023 by Tri Dao
Parent: f5d0fbd4

Showing 5 changed files with 209 additions and 9 deletions (+209, -9):

  flash_attn/models/gpt.py        +6    -3
  flash_attn/models/gpt_neox.py   +107  -0
  flash_attn/models/gptj.py       +5    -0
  tests/models/test_gpt_neox.py   +84   -0
  tests/models/test_gptj.py       +7    -6
flash_attn/models/gpt.py  (+6, -3)

@@ -25,6 +25,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.models.gptj import remap_state_dict_hf_gptj
+from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox

 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -205,6 +206,8 @@ class GPTPreTrainedModel(nn.Module):
         elif model_name.startswith('EleutherAI/gpt-j-'):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
             strict = False  # We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
+        elif model_name.startswith('EleutherAI/gpt-neox-'):
+            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
         else:
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
@@ -355,6 +358,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.process_group = process_group
         self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
         self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
+        lm_head_bias = getattr(config, 'lm_head_bias', False)
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
@@ -366,13 +370,12 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         else:
             self.project_out = None
         if process_group is None:
-            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings,
-                                     **factory_kwargs)
+            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
             self.lm_head = ColumnParallelLinear(
-                embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings,
+                embed_dim, vocab_size, process_group, bias=lm_head_bias,
                 sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
             )
         # Initialize weights and apply final processing
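With the new `elif` branch in `from_pretrained`, GPT-NeoX checkpoints load through the same path as the other supported architectures, with `remap_state_dict_hf_gpt_neox` converting the HF weights on the fly. A minimal usage sketch, assuming a CUDA device with enough memory and that the HF weights for `EleutherAI/gpt-neox-20b` are available (this mirrors the setup in the new test further down):

```python
# Usage sketch only: load a GPT-NeoX checkpoint through GPTLMHeadModel.from_pretrained,
# which now dispatches model names starting with 'EleutherAI/gpt-neox-' to the remap code.
import torch
from transformers import GPTNeoXConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config

model_name = 'EleutherAI/gpt-neox-20b'
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()
```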
flash_attn/models/gpt_neox.py  (new file, mode 100644; +107, -0)

# Copyright (c) 2023, Tri Dao.

import math
import re

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r'^gpt_neox.', 'transformer.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('embed_out.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
                     r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
                     r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
                     r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
                     r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f'transformer.layers.{l}.attention.bias')
        state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
            three=3, headdim=headdim
        )
        bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
            bqkv, '(nheads three headdim) -> (three nheads headdim)',
            three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
                     r'transformer.layers.\1.mixer.out_proj.', key)
        key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
                     r'transformer.layers.\1.mixer.rotary_emb.', key)
        return key
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    assert gpt_neox_config.rotary_emb_base == 10000
    return GPT2Config(
        vocab_size=gpt_neox_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gpt_neox_config.hidden_size,
        n_layer=gpt_neox_config.num_hidden_layers,
        n_head=gpt_neox_config.num_attention_heads,
        n_inner=gpt_neox_config.intermediate_size,
        activation_function=gpt_neox_config.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
        initializer_range=gpt_neox_config.initializer_range,
        bos_token_id=gpt_neox_config.bos_token_id,
        eos_token_id=gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=gpt_neox_config.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=gpt_neox_config.rotary_pct,
        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
    )
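The main subtlety in the remap above is the fused QKV layout: the HF GPT-NeoX checkpoint groups the projection rows per head as (nheads, 3, headdim), while this repo's mixer expects them grouped as (3, nheads, headdim), i.e. all Q heads first, then K, then V. A toy, self-contained sketch of the same `rearrange` pattern (sizes are made up for illustration):

```python
# Toy illustration of the Wqkv layout conversion done in remap_state_dict_hf_gpt_neox.
# Only the rearrange pattern matches the remap code; the sizes are invented.
import torch
from einops import rearrange

nheads, headdim, hidden_dim = 4, 8, 32
# HF GPT-NeoX layout: output rows grouped as (nheads, 3, headdim)
Wqkv_neox = torch.randn(nheads * 3 * headdim, hidden_dim)
# Target layout: output rows grouped as (3, nheads, headdim)
Wqkv = rearrange(Wqkv_neox, '(nheads three headdim) ... -> (three nheads headdim) ...',
                 three=3, headdim=headdim)
assert Wqkv.shape == Wqkv_neox.shape  # same size, rows permuted
```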
flash_attn/models/gptj.py  (+5, -0)

@@ -34,6 +34,10 @@ def remap_state_dict_hf_gptj(state_dict, config):
         state_dict['lm_head.weight'] = F.pad(
             output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
         )
+        output_embeddings_bias = state_dict.pop('lm_head.bias')
+        state_dict['lm_head.bias'] = F.pad(
+            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
+        )

     # LayerNorm
     def key_mapping_ln(key):
@@ -92,4 +96,5 @@ def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
         tie_word_embeddings=False,
         qkv_proj_bias=False,
         out_proj_bias=False,
+        lm_head_bias=True,
     )
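Unlike GPT-NeoX, GPT-J's `lm_head` carries a bias, so when `vocab_size` is rounded up to a multiple of `pad_vocab_size_multiple` the bias vector has to be padded along with the weight (hence the new `lm_head_bias=True` config flag). A toy sketch of the padding arithmetic, with illustrative sizes that are not taken from the diff:

```python
# Toy illustration of padding the lm_head bias to the rounded-up vocab size,
# matching the new lines in remap_state_dict_hf_gptj. Sizes are made up.
import math
import torch
import torch.nn.functional as F

orig_vocab_size, pad_vocab_size_multiple = 50257, 8
vocab_size = math.ceil(orig_vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple  # 50264
output_embeddings_bias = torch.randn(orig_vocab_size)
# F.pad on a 1D tensor pads the last dimension: (left, right)
padded_bias = F.pad(output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]))
assert padded_bias.shape[0] == vocab_size
```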
tests/models/test_gpt_neox.py  (new file, mode 100644; +84, -0)

import time

import torch
import pytest

from transformers import GPTNeoXConfig, AutoTokenizer
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gptj_state_dict(model_name):
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gpt_neox(state_dict_from_pretrained(model_name),
                                                         config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_optimized(model_name):
    """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
    config.fused_dropout_add_ln = False  # We don't support parallel block yet
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Need at least 2 GPUs, otherwise we'll OOM
    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = GPTNeoXForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
                                                  device_map={"": device})
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
    assert (logits - logits_ref).abs().mean().item() < 2 * (logits_hf - logits_ref).abs().mean().item()
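The tolerance logic is the same pattern used by the other model tests: the HF fp32 forward pass acts as the reference, and the optimized fp16 output only has to stay within twice the error that HF's own fp16 forward pass incurs. A condensed sketch of that acceptance criterion, using synthetic placeholder tensors rather than real model outputs:

```python
# Condensed sketch of the "2x the HF fp16 error" criterion used in the test above.
# out_ref stands in for the HF fp32 reference, out_hf for the HF fp16 output,
# and out for the flash-attention fp16 output; all are synthetic here.
import torch

out_ref = torch.randn(2, 256, 64)
out_hf = out_ref + 1e-3 * torch.randn_like(out_ref)  # simulated HF fp16 rounding error
out = out_ref + 5e-4 * torch.randn_like(out_ref)     # simulated flash-attention fp16 error

# Accept if our deviation from fp32 is at most 2x the deviation HF itself shows in fp16.
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
```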
tests/models/test_gptj.py  (+7, -6)

@@ -3,7 +3,7 @@ import re
 import torch
 import pytest

-from transformers import GPTJConfig
+from transformers import GPTJConfig, AutoTokenizer
 from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
@@ -37,7 +37,6 @@ def test_gptj_optimized(model_name):
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = False  # We don't support parallel block yet
-    # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = True

     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
@@ -46,22 +45,24 @@ def test_gptj_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
-                              device='cuda')
+                              device=device)
     with torch.no_grad():
         out = model.transformer(input_ids)
         logits = model(input_ids).logits
     del model

-    model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
     model_ref.eval()
     with torch.no_grad():
         out_ref = model_ref.transformer(input_ids).last_hidden_state
         logits_ref = model_ref(input_ids).logits
     del model_ref

-    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
     model_hf.eval()
     out_hf = model_hf.transformer(input_ids).last_hidden_state
     logits_hf = model_hf(input_ids).logits
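The substance of this test change is how the reference models are loaded: instead of calling `from_pretrained(...)` (which materializes the weights on the CPU) and then moving the model with `.to(device=device)`, the models are placed directly on the GPU via `device_map={"": device}`, which relies on the `accelerate` integration in `transformers`. A sketch of the new pattern; the checkpoint name below is assumed for illustration, since the test's parametrization is not shown in this diff:

```python
# Sketch of the loading pattern the test switches to. device_map={"": device} asks
# transformers (via accelerate) to place the whole model on `device` at load time,
# instead of loading on CPU and moving it afterwards with .to(device=device).
# The checkpoint name is an assumption, not taken from the diff.
import torch
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

device = 'cuda'
dtype = torch.float16
model_hf = GPTJForCausalLM.from_pretrained('EleutherAI/gpt-j-6B', torch_dtype=dtype,
                                           device_map={"": device})
model_hf.eval()
```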