gaoqiong / flash-attention

Commit 11be742a, authored Jan 07, 2023 by Tri Dao

[Gen] Test generation with rotary embedding

parent 8d9674ed
Showing 4 changed files with 42 additions and 29 deletions.
flash_attn/models/gpt.py             +5 -3
flash_attn/modules/mha.py            +2 -3
flash_attn/utils/pretrained.py       +2 -2
tests/models/test_gpt_generation.py  +33 -21
flash_attn/models/gpt.py

@@ -146,15 +146,17 @@ class GPTPreTrainedModel(nn.Module):
         self.config = config
 
     @classmethod
-    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
         """
         Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
         """
         # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
+        model = cls(config, *args, device=device, **kwargs)
         load_return = model.load_state_dict(
-            remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
+            remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
+            strict=strict)
         logger.info(load_return)
         return model
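The new strict flag is forwarded to nn.Module.load_state_dict, which with strict=False tolerates missing and unexpected keys and reports them in its return value (the same value that logger.info(load_return) records). A minimal, self-contained illustration of that PyTorch behaviour, using toy modules rather than the flash_attn classes:

import torch.nn as nn

# Pretend checkpoint: a plain Linear whose keys are 'weight' and 'bias'.
src = nn.Linear(4, 4)
# Target model whose parameters are named '0.weight' and '0.bias'.
dst = nn.Sequential(nn.Linear(4, 4))

# strict=False: mismatched keys are reported instead of raising an error.
load_return = dst.load_state_dict(src.state_dict(), strict=False)
print(load_return.missing_keys)     # ['0.weight', '0.bias']
print(load_return.unexpected_keys)  # ['weight', 'bias']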
flash_attn/modules/mha.py

@@ -341,7 +341,6 @@ class MHA(nn.Module):
             self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                         groups=3 * embed_dim)
         else:
-            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)

@@ -482,9 +481,9 @@ class MHA(nn.Module):
                                  'b d s -> b s d').contiguous()
             if inference_params is None:
                 if not self.checkpointing:
-                    context = self.inner_attn(q, kv, **kwargs)
+                    context = self.inner_cross_attn(q, kv, **kwargs)
                 else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
             else:
                 kv = self._update_kv_cache(kv)
                 context = self.inner_cross_attn(q, kv, causal=False)
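The rename matters because the cross-attention branch should dispatch to the module stored as inner_cross_attn, matching the KV-cache path below, which already calls inner_cross_attn. For reference, torch.utils.checkpoint.checkpoint re-runs the wrapped call during backward to trade compute for memory; a minimal sketch of that pattern with a toy cross-attention module (illustrative only, not the flash_attn implementation):

import torch
import torch.utils.checkpoint
import torch.nn as nn


class ToyCrossAttention(nn.Module):
    """Illustrative stand-in for the inner cross-attention module."""
    def forward(self, q, kv):
        # q: (batch, seqlen_q, nheads, head_dim); kv: (batch, seqlen_k, 2, nheads, head_dim)
        k, v = kv.unbind(dim=2)
        scores = torch.einsum('bshd,bthd->bhst', q, k) / q.shape[-1] ** 0.5
        attn = torch.softmax(scores, dim=-1)
        return torch.einsum('bhst,bthd->bshd', attn, v)


attn_module = ToyCrossAttention()
q = torch.randn(2, 16, 8, 64, requires_grad=True)
kv = torch.randn(2, 32, 2, 8, 64, requires_grad=True)

# Activation checkpointing: intermediate activations are dropped in the forward
# pass and recomputed during backward, saving memory at the cost of extra compute.
context = torch.utils.checkpoint.checkpoint(attn_module, q, kv)
context.sum().backward()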
flash_attn/utils/pretrained.py

@@ -4,5 +4,5 @@ from transformers.utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_file
 
 
-def state_dict_from_pretrained(model_name):
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
+def state_dict_from_pretrained(model_name, device=None):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
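map_location tells torch.load where to materialise the deserialized tensors, so a CPU-saved checkpoint can be placed directly on the chosen device (and device=None keeps torch's default behaviour). A small generic sketch of the same mechanism, independent of the Hugging Face cached_file helper:

import torch

# Save a toy checkpoint ...
torch.save({'w': torch.randn(4, 4)}, '/tmp/ckpt.pt')

# ... and load it straight onto the chosen device. With map_location=None the
# tensors come back on whatever device was recorded in the file.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load('/tmp/ckpt.pt', map_location=device)
print(state_dict['w'].device)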
tests/models/test_gpt_generation.py

@@ -14,32 +14,40 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import greedy_decode
 
-# TODO: test with rotary embedding
 
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [False])
 # @pytest.mark.parametrize('optimized', [True])
+@pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized, fused_ft_kernel):
+def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
     """
     dtype = torch.float16
+    device = 'cuda'
     rtol, atol = 3e-3, 3e-1
     config = GPT2Config.from_pretrained(model_name)
+    if rotary:
+        config.n_positions = 0
+        config.rotary_emb_dim = 64
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
         config.fused_dense_gelu_dense = True
         config.fused_dropout_add_ln = True
-    model = GPTLMHeadModel.from_pretrained(model_name, config)
-    model = model.cuda().to(dtype=dtype)
-    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
-    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
-    model.eval()
-    model_ref.eval()
-    model_hf.eval()
+    # if not rotary, we load the weight from HF but ignore the position embeddings.
+    # The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
+    model = model.to(dtype=dtype)
+    model.eval()
+    if not rotary:
+        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+        model_ref.eval()
+        model_hf.eval()

@@ -47,6 +55,8 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
     max_length = 30
+    # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
+    # max_length = 512 + 50
     # Slow generation for reference
     sequences = []

@@ -66,6 +76,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
                          fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
-    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
-                               return_dict_in_generate=True, output_scores=True)
-    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+    if not rotary:
+        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                                   return_dict_in_generate=True, output_scores=True)
+        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,

@@ -79,6 +90,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                           rtol=rtol, atol=atol)
-    assert torch.all(out.sequences == out_ref.sequences)
-    assert torch.all(out.sequences == out_hf.sequences)
+    if not rotary:
+        assert torch.all(out.sequences == out_ref.sequences)
+        assert torch.all(out.sequences == out_hf.sequences)
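Stripped of the pytest parametrization, the new rotary path of this test amounts to the condensed sketch below. It assumes a CUDA device and that GPTLMHeadModel is importable from flash_attn.models.gpt; the config fields and the keyword arguments to model.generate are the ones visible in the diff above.

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel  # assumed import path

device, dtype = 'cuda', torch.float16

config = GPT2Config.from_pretrained('gpt2')
config.n_positions = 0       # no learned position embeddings ...
config.rotary_emb_dim = 64   # ... rotary embeddings instead

# strict=not rotary, i.e. strict=False here: the HF position-embedding weights are
# skipped, so the model is numerically nonsense but still exercises the rotary
# generation code path, which is all the test needs.
model = GPTLMHeadModel.from_pretrained('gpt2', config, strict=False, device=device)
model = model.to(dtype=dtype)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()

out = model.generate(input_ids=input_ids, max_length=30, fused_ft_kernel=False,
                     return_dict_in_generate=True, output_scores=True)
print(out.sequences.shape)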