gaoqiong / flash-attention / Commits

Commit 4d87e4d8, authored Mar 22, 2023 by Tri Dao
Implement GPT-J
parent 4360cfc6

Changes (11): showing 11 changed files with 522 additions and 87 deletions (+522 -87)
flash_attn/models/gpt.py                      +213  -56
flash_attn/models/gptj.py                     +95   -0
flash_attn/models/opt.py                      +1    -1
flash_attn/modules/block.py                   +90   -0
flash_attn/modules/mha.py                     +25   -19
flash_attn/utils/generation.py                +11   -4
tests/models/test_gpt.py                      +2    -2
tests/models/test_gpt_generation.py           +2    -2
tests/models/test_gpt_generation_parallel.py  +1    -1
tests/models/test_gptj.py                     +80   -0
tests/models/test_opt.py                      +2    -2
flash_attn/models/gpt.py

@@ -18,12 +18,13 @@ from einops import rearrange
 from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
-from flash_attn.modules.block import Block
+from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin
-from flash_attn.models.opt import remap_state_dict_opt
+from flash_attn.models.opt import remap_state_dict_hf_opt
+from flash_attn.models.gptj import remap_state_dict_hf_gptj

 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -36,9 +37,10 @@ except ImportError:
     dropout_add_layer_norm = None

 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
 except ImportError:
     FusedDenseSqreluDense = None
+    sqrelu_fwd = None

 logger = logging.getLogger(__name__)
@@ -54,8 +56,11 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     dwconv = getattr(config, 'attn_dwconv', False)
     if dwconv:
         assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
+    qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
+    out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
-    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
+    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
+    rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
     if not fused_bias_fc:
@@ -66,9 +71,12 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     parallel_kwargs = ({'process_group': process_group,
                         'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                        if process_group is not None else {})
-    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
+    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
+                        qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
+                        dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                        rotary_emb_interleaved=rotary_emb_interleaved,
                         use_flash_attn=use_flash_attn,
                         **serial_kwargs, **parallel_kwargs, **factory_kwargs)
     return mixer_cls
@@ -88,8 +96,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     if process_group is not None:
         assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
     if not fused_mlp and not fused_dense_sqrelu_dense:
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
         if config.activation_function == 'relu':
             activation = partial(F.relu, inplace=True)
+        elif config.activation_function == 'sqrelu':
+            assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
+            activation = sqrelu_fwd
         else:
             approximate = ('tanh' if config.activation_function
                            in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
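Note: the 'sqrelu' activation added here is squared ReLU; the fused Triton kernel sqrelu_fwd computes it, and the assert above only guards against the kernel being unavailable. A plain-PyTorch sketch of the same activation, for reference only (this is not the fused kernel):

import torch
import torch.nn.functional as F

def sqrelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU: relu(x) ** 2, the activation selected by activation_function='sqrelu'.
    return torch.square(F.relu(x))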
@@ -132,12 +144,27 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
     residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
     resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
     prenorm = getattr(config, 'prenorm', True)
-    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
-                  prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
-                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
-                  residual_in_fp32=residual_in_fp32,
-                  sequence_parallel=sequence_parallel and process_group is not None,
-                  mark_shared_params=process_group is not None)
+    parallel_block = getattr(config, 'parallel_block', False)
+    if not parallel_block:
+        block = Block(
+            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
+            prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
+            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
+            residual_in_fp32=residual_in_fp32,
+            sequence_parallel=sequence_parallel and process_group is not None,
+            mark_shared_params=process_group is not None
+        )
+    else:
+        assert prenorm
+        block = ParallelBlock(
+            config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
+            resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
+            tied_norm=getattr(config, 'parallel_block_tied_norm', False),
+            fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
+            residual_in_fp32=residual_in_fp32,
+            sequence_parallel=sequence_parallel and process_group is not None,
+            mark_shared_params=process_group is not None
+        )
     block.layer_idx = layer_idx
     return block
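Note: the ParallelBlock branch above selects the GPT-J / GPT-NeoX / PaLM formulation, where attention and the MLP both read the same normalized input and both outputs are added into a single residual stream, instead of running one after the other. A minimal sketch of the two formulations with hypothetical ln/attn/mlp callables (illustrative, not the classes in this diff):

# Sequential (standard pre-norm) block:
#   x = x + attn(ln1(x)); x = x + mlp(ln2(x))
# Parallel (GPT-J-style) block with a tied norm:
#   x = x + attn(ln(x)) + mlp(ln(x))
def parallel_block_reference(x, ln, attn, mlp):
    h = ln(x)
    return x + attn(h) + mlp(h)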
@@ -172,9 +199,12 @@ class GPTPreTrainedModel(nn.Module):
             model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
-            state_dict = remap_state_dict_gpt2(state_dict, config)
+            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith('facebook/opt'):
-            state_dict = remap_state_dict_opt(state_dict, config)
+            state_dict = remap_state_dict_hf_opt(state_dict, config)
+        elif model_name.startswith('EleutherAI/gpt-j-'):
+            state_dict = remap_state_dict_hf_gptj(state_dict, config)
+            strict = False  # We have rotary_emb.inf_freq buffers not in the GPT-J checkpoint
         else:
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
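Note: with the new branch above, a GPT-J checkpoint goes through the same from_pretrained path as GPT-2 and OPT; strict=False is needed only because this implementation registers rotary inv_freq buffers that the Hugging Face checkpoint does not carry. A rough usage sketch, mirroring tests/models/test_gptj.py further down (illustrative only):

import torch
from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained('EleutherAI/gpt-j-6B'))
# from_pretrained remaps the HF state dict and loads it with strict=False internally.
model = GPTLMHeadModel.from_pretrained('EleutherAI/gpt-j-6B', config,
                                       device='cuda', dtype=torch.float16)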
@@ -223,6 +253,8 @@ class GPTModel(GPTPreTrainedModel):
         # These 2 options are for OPT-350m
         self.prenorm = getattr(config, 'prenorm', True)
         word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
+        # For GPT-J, GPT-NeoX
+        self.parallel_block = getattr(config, 'parallel_block', False)

         if process_group is None:
             self.embeddings = GPT2Embeddings(
@@ -276,6 +308,8 @@ class GPTModel(GPTPreTrainedModel):
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                             if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
+        if self.parallel_block:
+            hidden_states2 = None
         residual = None
         mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                         if self.process_group is not None and self.sequence_parallel else {})
@@ -283,15 +317,27 @@ class GPTModel(GPTPreTrainedModel):
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
             if self.prenorm:
-                hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+                if not self.parallel_block:
+                    hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+                else:
+                    hidden_states, hidden_states2, residual = layer(
+                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
+                    )
             else:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_f(hidden_states)
-                residual = (dropped + residual) if residual is not None else dropped
+                if not self.parallel_block:
+                    residual = (dropped + residual) if residual is not None else dropped
+                else:
+                    dropped2 = self.drop_f(hidden_states2)
+                    residual = ((residual + dropped + dropped2)
+                                if residual is not None else dropped + dropped2)
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
+                assert not self.parallel_block
                 # Set prenorm=False here since we don't need the residual
                 hidden_states = dropout_add_layer_norm(
                     hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
@@ -308,6 +354,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         super().__init__(config)
         self.process_group = process_group
         self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
+        self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
@@ -319,12 +366,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         else:
             self.project_out = None
         if process_group is None:
-            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
+            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings,
+                                     **factory_kwargs)
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
             self.lm_head = ColumnParallelLinear(
-                embed_dim, vocab_size, process_group, bias=False,
+                embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings,
                 sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
             )
         # Initialize weights and apply final processing
@@ -333,7 +381,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.tie_weights()

     def tie_weights(self):
-        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.tie_word_embeddings:
+            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
@@ -381,7 +430,95 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         return super().load_state_dict(state_dict, strict=strict)


-def remap_state_dict_gpt2(state_dict, config):
+def shard_state_dict_tp(state_dict, config, world_size, rank):
+    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
+    with tensor parallel.
+    """
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
+    assert vocab_size % world_size == 0
+    assert config.hidden_size % world_size == 0
+    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+    assert inner_dim % world_size == 0
+
+    def shard_first_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[0] // world_size
+        state_dict[key] = x[rank * dim:(rank + 1) * dim]
+
+    def shard_last_dim(state_dict, key):
+        x = state_dict[key]
+        dim = x.shape[-1] // world_size
+        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
+
+    def shard_qkv_headdim(state_dict, key):
+        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
+        dim = x.shape[1] // world_size
+        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
+                                    'three d ... -> (three d) ...')
+
+    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
+    if 'lm_head.weight' in state_dict:
+        shard_first_dim(state_dict, 'lm_head.weight')
+    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
+        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
+    for i in range(config.num_hidden_layers):
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
+        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
+        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
+        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
+        if rank != 0:
+            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
+    return state_dict
+
+
+def combine_state_dicts_tp(state_dicts, config):
+    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
+    with tensor parallel.
+    """
+    world_size = len(state_dicts)
+    keys = state_dicts[0].keys()
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
+    assert vocab_size % world_size == 0
+    assert config.hidden_size % world_size == 0
+    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
+    assert inner_dim % world_size == 0
+
+    # The word embeddings from Megatron are weird, for each shard only the first
+    # vocab_size // world_size coordinates are nonzero.
+    def combine_word_embeddings(state_dicts, state_dict, key):
+        assert all(s[key].shape[0] == vocab_size for s in state_dicts)
+        state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0)

+    def combine_dim(state_dicts, state_dict, key, dim=-1):
+        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
+
+    def combine_qkv_headdim(state_dicts, state_dict, key):
+        xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
+        state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+
+    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
+    combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
+    if 'lm_head.weight' in state_dict:
+        combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
+    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
+        combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
+    for i in range(config.num_hidden_layers):
+        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
+        combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
+        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
+        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0)
+        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
+        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
+    return state_dict
+
+
+def remap_state_dict_hf_gpt2(state_dict, config):
     # Word embedding and position embedding
     def key_mapping_pos_emb(key):
         return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
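Note on shard_qkv_headdim above: the packed Wqkv weight stores Q, K and V stacked along the first dimension, so a naive first-dimension split would hand rank 0 all of Q plus part of K. The helper instead splits each projection separately. A self-contained toy illustration with hypothetical small shapes (not code from this commit):

import torch
from einops import rearrange

hidden, world_size, rank = 8, 2, 0
Wqkv = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)

# Split each of Q, K, V along its own rows, exactly as shard_qkv_headdim does.
x = rearrange(Wqkv, '(three d) h -> three d h', three=3)
dim = x.shape[1] // world_size
shard = rearrange(x[:, rank * dim:(rank + 1) * dim], 'three d h -> (three d) h')

# Rank 0 ends up with the first half of Q, of K, and of V.
assert torch.equal(shard[:dim], Wqkv[:dim])                              # Q slice
assert torch.equal(shard[dim:2 * dim], Wqkv[hidden:hidden + dim])        # K slice
assert torch.equal(shard[2 * dim:], Wqkv[2 * hidden:2 * hidden + dim])   # V slice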
@@ -430,47 +567,67 @@ def remap_state_dict_gpt2(state_dict, config):
     return state_dict


-def shard_state_dict_tp(state_dict, config, world_size, rank):
-    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
-    with tensor parallel.
-    """
-    [... the remaining removed lines are the body of shard_state_dict_tp, moved unchanged
-     to its new location earlier in this file ...]
+def remap_state_dict_megatron(state_dict, config):
+    def key_mapping_transformer(key):
+        key = re.sub(r'^language_model.encoder.', 'transformer.', key)
+        key = re.sub(r'^language_model.', 'transformer.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
+
+    # Word embedding and position embedding
+    def key_mapping_pos_emb(key):
+        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
+    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
+        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
+                     r'transformer.layers.\1.norm1.\2', key)
+        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
+                     r'transformer.layers.\1.norm2.\2', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    # MLP
+    def key_mapping_mlp(key):
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
+                     r'transformer.layers.\1.mlp.fc1.\2', key)
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
+                     r'transformer.layers.\1.mlp.fc2.\2', key)
+        return key
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # Attention
+    def key_mapping_attn(key):
+        key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
+                     r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
+        key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
+                     r'transformer.layers.\1.mixer.Wqkv.\2', key)
+        key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
+                     r'transformer.layers.\1.mixer.out_proj.\2', key)
+        return key
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
+    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
+    headdim = config.hidden_size // config.num_attention_heads
+    for d in range(config.num_hidden_layers):
+        Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
+        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
+            Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
+            three=3, headdim=headdim
+        )
+        bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
+        state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
+            bqkv, '(nheads three headdim) -> (three nheads headdim)',
+            three=3, headdim=headdim
+        )
+
+    return state_dict
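Note: the final loop of remap_state_dict_megatron reorders the packed QKV rows from Megatron's (nheads, 3, headdim) grouping into this repo's (3, nheads, headdim) grouping. A toy demonstration of that rearrange with hypothetical small shapes (not code from this commit):

import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 4
# Megatron packs rows as q0 k0 v0 q1 k1 v1 (per head); this repo wants q0 q1 k0 k1 v0 v1.
Wqkv_megatron = torch.arange(nheads * 3 * headdim * hidden,
                             dtype=torch.float32).reshape(nheads * 3 * headdim, hidden)
Wqkv_ours = rearrange(Wqkv_megatron, '(nheads three headdim) d -> (three nheads headdim) d',
                      three=3, headdim=headdim)
# Head 1's query rows move from Megatron offset 3*headdim to our offset 1*headdim.
assert torch.equal(Wqkv_ours[headdim:2 * headdim], Wqkv_megatron[3 * headdim:4 * headdim])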
flash_attn/models/gptj.py (new file, +95 -0)

# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r'^transformer.h.', 'transformer.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('lm_head.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f'transformer.layers.{l}.attn.bias')
        state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')

    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
    )
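Note: these two helpers can also be used directly, without going through GPTLMHeadModel.from_pretrained; this is essentially what tests/models/test_gptj.py below does. A rough sketch (illustrative only):

from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = 'EleutherAI/gpt-j-6B'
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
# strict=False because the model registers rotary_emb.inv_freq buffers that the
# HF GPT-J checkpoint does not contain.
model.load_state_dict(pretrained_state_dict, strict=False)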
flash_attn/models/opt.py

@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from transformers import GPT2Config, OPTConfig


-def remap_state_dict_opt(state_dict, config):
+def remap_state_dict_hf_opt(state_dict, config):
     def key_mapping_model(key):
         key = re.sub(r'^model.decoder.', 'transformer.', key)
         # The OPT-350m model uses '^decoder' instead of '^model.decoder'
flash_attn/modules/block.py

@@ -190,3 +190,93 @@ class Block(nn.Module):
                 rowscale=rowscale2, prenorm=False
             )
             return hidden_states
+
+
+class ParallelBlock(nn.Module):
+    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
+    and PaLM.
+    """
+
+    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
+                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
+                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
+                 sequence_parallel=False, mark_shared_params=False):
+        """
+        This Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
+        the hidden_states (output1 of the MHA / MLP) and the residual.
+        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+        """
+        super().__init__()
+        self.tied_norm = tied_norm
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
+        self.residual_in_fp32 = residual_in_fp32
+        if mixer_cls is None:
+            mixer_cls = partial(MHA, num_heads=dim // 64)
+        if mlp_cls is None:
+            mlp_cls = partial(Mlp, hidden_features=4 * dim)
+        self.mixer = mixer_cls(dim)
+        self.dropout1 = dropout_cls(resid_dropout1)
+        self.norm1 = norm_cls(dim)
+        self.mlp = mlp_cls(dim)
+        self.dropout2 = dropout_cls(resid_dropout2)
+        if not self.tied_norm:
+            self.norm2 = norm_cls(dim)
+
+        if self.fused_dropout_add_ln:
+            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
+        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
+        if sequence_parallel:
+            for p in self.norm1.parameters():
+                p._sequence_parallel = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._shared_params = True
+
+    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
+                residual: Optional[Tensor] = None, mixer_kwargs=None):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states1: the output of the previous attention (mixer) or embedding layer.
+            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
+            residual.
+        """
+        dropped1 = self.dropout1(hidden_states1)
+        # For the very 1st block, we only want 1 dropout, not two different dropouts
+        if hidden_states2 is not None:
+            dropped2 = self.dropout2(hidden_states2)
+            residual = ((residual + dropped1 + dropped2)
+                        if residual is not None else dropped1 + dropped2)
+        else:
+            residual = (residual + dropped1) if residual is not None else dropped1
+        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                          if not self.tied_norm else hidden_states1)
+        if self.residual_in_fp32:
+            residual = residual.to(torch.float32)
+        if mixer_kwargs is None:
+            mixer_kwargs = {}
+        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
+        hidden_states2 = self.mlp(hidden_states2)
+        return hidden_states1, hidden_states2, residual
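Note: each ParallelBlock returns the raw mixer output, the raw MLP output and the running residual, so a stack of them is driven the way GPTModel.forward above does: thread all three through every layer and fold the last pair into the residual at the end. A minimal sketch with hypothetical sizes, dropout and fused paths left out (whether it runs as-is depends on which optional flash-attn extensions are installed):

import torch
from functools import partial
from torch import nn
from flash_attn.modules.block import ParallelBlock
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

dim, num_heads, depth = 256, 8, 2
mixer_cls = partial(MHA, num_heads=num_heads, causal=True, use_flash_attn=False)
mlp_cls = partial(Mlp, hidden_features=4 * dim)
layers = nn.ModuleList([ParallelBlock(dim, mixer_cls, mlp_cls, tied_norm=True)
                        for _ in range(depth)])
ln_f = nn.LayerNorm(dim)

x = torch.randn(2, 128, dim)
hidden1, hidden2, residual = x, None, None
for layer in layers:
    hidden1, hidden2, residual = layer(hidden1, hidden2, residual)
# Final residual update mirrors the parallel-block path in GPTModel.forward.
out = ln_f(residual + hidden1 + hidden2)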
flash_attn/modules/mha.py

@@ -347,9 +347,10 @@ class MHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0,
+    def __init__(self, embed_dim, num_heads, cross_attn=False,
+                 qkv_proj_bias=True, out_proj_bias=True, dropout=0.0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
+                 rotary_emb_scale_base=None, rotary_emb_interleaved=False,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
         """

@@ -377,7 +378,7 @@ class MHA(nn.Module):
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)
         if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')

@@ -388,29 +389,32 @@ class MHA(nn.Module):
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
-                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias, **factory_kwargs)
             else:
-                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
+                                             **factory_kwargs)
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
-            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
             if not self.return_residual:
-                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias, **factory_kwargs)
             else:
-                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
+                                            **factory_kwargs)
             if self.dwconv:
                 self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                           groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                            groups=2 * embed_dim)
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
-        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
+        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

@@ -526,9 +530,10 @@ class ParallelMHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+    def __init__(self, embed_dim, num_heads, process_group,
+                 qkv_proj_bias=True, out_proj_bias=True,
+                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
+                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
+                 use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -546,11 +551,12 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)

         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
-        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
+        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=qkv_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention

@@ -558,8 +564,8 @@ class ParallelMHA(nn.Module):
                                          attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, bias=out_proj_bias,
                                           sequence_parallel=sequence_parallel, **factory_kwargs)

     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
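Note on rotary_emb_interleaved: GPT-J applies rotary embeddings to adjacent channel pairs (0, 1), (2, 3), ..., while the GPT-NeoX-style default pairs channel i with channel i + rotary_dim/2; the new flag selects between the two so GPT-J checkpoints rotate the same channels they were trained with. A plain-PyTorch sketch of the two pairing conventions (reference only, assuming cos/sin are already expanded to the rotary dim in the matching layout; this is not the fused rotary kernel):

import torch

def rotate_half(x):
    # GPT-NeoX convention: pair channel i with channel i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleaved(x):
    # GPT-J convention: pair adjacent channels (0, 1), (2, 3), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary(x, cos, sin, interleaved=False):
    rot = rotate_interleaved(x) if interleaved else rotate_half(x)
    return x * cos + rot * sin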
flash_attn/utils/generation.py

@@ -71,8 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):

 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
-           cg=False, timing=False):
+           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
+           fused_ft_kernel=False, cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,

@@ -87,6 +87,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
+    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
         assert fused_ft_kernel
         if not hasattr(model, '_decoding_cache'):

@@ -111,7 +112,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     if vocab_size is not None:
         logits = logits[..., :vocab_size]
     scores.append(logits)
-    next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    if teacher_outputs is None or teacher_output_len <= seqlen_og:
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    else:
+        next_token = teacher_outputs[:, seqlen_og]
     sequences = [next_token]
     inference_params.sequence_len_offset = seqlen_og
     while True:

@@ -126,7 +130,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
-        next_token = sample(logits, top_k=top_k, temperature=temperature)
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
+            next_token = sample(logits, top_k=top_k, temperature=temperature)
+        else:
+            next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
         sequences.append(next_token)
         inference_params.sequence_len_offset += 1
         if eos_token_id is not None and (next_token == eos_token_id).all():
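Note: the new teacher_outputs argument lets decode force a fixed token at every position it covers instead of sampling, which makes the incremental-decoding / fused-kernel path easy to check against a known sequence; once teacher_outputs runs out (or when it is None) sampling proceeds as before. A rough usage sketch, assuming a hypothetical model (a GPTLMHeadModel) and an input_ids tensor already on the right device:

import torch
from flash_attn.utils.generation import decode

max_length = 32
teacher_outputs = torch.randint(0, model.config.vocab_size,
                                (input_ids.shape[0], max_length), device=input_ids.device)
# Positions covered by teacher_outputs are forced; the rest fall back to sampling.
output = decode(input_ids, model, max_length, top_k=1,
                teacher_outputs=teacher_outputs, fused_ft_kernel=False)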
tests/models/test_gpt.py

@@ -7,7 +7,7 @@ from transformers import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_state_dict(model_name):
     config = GPT2Config.from_pretrained(model_name)
-    pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()
tests/models/test_gpt_generation.py

@@ -12,8 +12,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.generation import update_graph_cache
tests/models/test_gpt_generation_parallel.py

@@ -12,7 +12,7 @@ from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
tests/models/test_gptj.py (new file, +80 -0)

import re

import torch
import pytest

from transformers import GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
                            for l in range(config.n_layer)}
    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
    for k in state_dict.keys() - rotary_inv_freq_keys:
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = False  # We don't support parallel block yet
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
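Note: the assertions above use a relative tolerance pattern rather than a fixed threshold: the fp16 flash-attn outputs must be within 3x of the error that Hugging Face's own fp16 model shows against the HF fp32 reference. Condensed into a reusable sketch (hypothetical helper, not part of this commit):

import torch

def assert_close_relative_to_baseline(ours_lowprec, ref_fp32, baseline_lowprec, factor=3):
    # Our low-precision error may be at most `factor` times the baseline's
    # low-precision error against the same fp32 reference.
    ours_err = (ours_lowprec - ref_fp32).abs().max().item()
    baseline_err = (baseline_lowprec - ref_fp32).abs().max().item()
    assert ours_err < factor * baseline_err, (ours_err, baseline_err)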
tests/models/test_opt.py

@@ -7,7 +7,7 @@ from transformers import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained

@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_state_dict(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
-    pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()