gaoqiong / flash-attention · Commits · 11be742a
"docs/source/en/api/models/unet-motion.md" did not exist on "9ced7844daffc1877b0c599a318371cfe09108f5"
Commit 11be742a, authored Jan 07, 2023 by Tri Dao

[Gen] Test generation with rotary embedding

Parent: 8d9674ed
Showing 4 changed files with 42 additions and 29 deletions (+42, -29)
flash_attn/models/gpt.py             +5  -3
flash_attn/modules/mha.py            +2  -3
flash_attn/utils/pretrained.py       +2  -2
tests/models/test_gpt_generation.py  +33 -21
flash_attn/models/gpt.py

@@ -146,15 +146,17 @@ class GPTPreTrainedModel(nn.Module):
         self.config = config
 
     @classmethod
-    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
         """
         Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
         """
         # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        load_return = model.load_state_dict(
-            remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
+        model = cls(config, *args, device=device, **kwargs)
+        load_return = model.load_state_dict(
+            remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
+            strict=strict)
         logger.info(load_return)
         return model
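For reference, a minimal usage sketch of the two new keyword arguments; the model name 'gpt2', the CUDA device, and the float16 cast are illustrative choices, while strict and device are the parameters added in the diff above:

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained('gpt2')
# strict=False lets load_state_dict tolerate checkpoint weights the model no longer has
# (e.g. learned position embeddings once rotary embeddings are enabled);
# device='cuda' is forwarded so the checkpoint tensors are mapped straight onto the GPU.
model = GPTLMHeadModel.from_pretrained('gpt2', config, strict=False, device='cuda')
model = model.to(dtype=torch.float16).eval()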
flash_attn/modules/mha.py

@@ -341,7 +341,6 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3,
                                             padding=2, groups=3 * embed_dim)
         else:
-            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)

@@ -482,9 +481,9 @@ class MHA(nn.Module):
                                    'b d s -> b s d').contiguous()
             if inference_params is None:
                 if not self.checkpointing:
-                    context = self.inner_attn(q, kv, **kwargs)
+                    context = self.inner_cross_attn(q, kv, **kwargs)
                 else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                    context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
             else:
                 kv = self._update_kv_cache(kv)
                 context = self.inner_cross_attn(q, kv, causal=False)
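In short, the cross-attention branch now always routes through self.inner_cross_attn; only the causal flag and the KV-cache update differ between full-sequence training and incremental decoding. A condensed sketch of that dispatch, using the names from the diff above (not the full MHA.forward):

import torch

def _cross_attn_context(self, q, kv, inference_params=None, **kwargs):
    if inference_params is None:
        # Training / evaluation over the full sequence.
        if not self.checkpointing:
            context = self.inner_cross_attn(q, kv, **kwargs)
        else:
            # Recompute attention in the backward pass to save activation memory.
            context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
    else:
        # Incremental decoding: append this step's keys/values to the cache,
        # then attend non-causally over everything cached so far.
        kv = self._update_kv_cache(kv)
        context = self.inner_cross_attn(q, kv, causal=False)
    return context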
flash_attn/utils/pretrained.py

@@ -4,5 +4,5 @@ from transformers.utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_file
 
 
-def state_dict_from_pretrained(model_name):
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
+def state_dict_from_pretrained(model_name, device=None):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
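Threading device through to torch.load's map_location avoids materializing the checkpoint on CPU first and copying it afterwards. A small usage sketch; 'gpt2' is just an example checkpoint name:

from flash_attn.utils.pretrained import state_dict_from_pretrained

state_dict_gpu = state_dict_from_pretrained('gpt2', device='cuda')  # tensors mapped straight to the GPU
state_dict_cpu = state_dict_from_pretrained('gpt2')                 # device=None keeps torch.load's default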
tests/models/test_gpt_generation.py

@@ -14,39 +14,49 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import greedy_decode
 
-# TODO: test with rotary embedding
 
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [False])
 # @pytest.mark.parametrize('optimized', [True])
+@pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized, fused_ft_kernel):
+def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
     """
     dtype = torch.float16
+    device = 'cuda'
     rtol, atol = 3e-3, 3e-1
     config = GPT2Config.from_pretrained(model_name)
+    if rotary:
+        config.n_positions = 0
+        config.rotary_emb_dim = 64
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
         config.fused_dense_gelu_dense = True
         config.fused_dropout_add_ln = True
 
-    model = GPTLMHeadModel.from_pretrained(model_name, config)
-    model = model.cuda().to(dtype=dtype)
-    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
-    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+    # if not rotary, we load the weight from HF but ignore the position embeddings.
+    # The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
+    model = model.to(dtype=dtype)
     model.eval()
-    model_ref.eval()
-    model_hf.eval()
+
+    if not rotary:
+        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+        model_ref.eval()
+        model_hf.eval()
 
     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
     max_length = 30
     # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
     # max_length = 512 + 50
 
     # Slow generation for reference
     sequences = []

@@ -66,20 +76,22 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
                          fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
-    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
-                               return_dict_in_generate=True, output_scores=True)
-    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
-                                 return_dict_in_generate=True, output_scores=True)
+    if not rotary:
+        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                                   return_dict_in_generate=True, output_scores=True)
+        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                     return_dict_in_generate=True, output_scores=True)
 
-    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
-    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
 
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                           rtol=rtol, atol=atol)
-    assert torch.all(out.sequences == out_ref.sequences)
-    assert torch.all(out.sequences == out_hf.sequences)
-    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+    if not rotary:
+        assert torch.all(out.sequences == out_ref.sequences)
+        assert torch.all(out.sequences == out_hf.sequences)
+        assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
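With rotary=True there is no Hugging Face reference to compare against (HF's GPT-2 only has learned absolute position embeddings), so the test checks the fast KV-cache decode against the repo's own slow greedy decode; the HF comparisons and the 3x score-error bound run only in the non-rotary case. An end-to-end sketch mirroring the rotary branch of the test; model name, rotary dimension, and prompt are the test's illustrative choices, and fused_ft_kernel=True assumes the optional fused kernels are built (pass False otherwise):

import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained('gpt2')
config.n_positions = 0      # drop learned position embeddings...
config.rotary_emb_dim = 64  # ...and use rotary embeddings instead

# Non-strict load: the HF position-embedding weights have no counterpart here, so the
# resulting model is nonsense, but that is fine for exercising the generation path.
model = GPTLMHeadModel.from_pretrained('gpt2', config, strict=False, device='cuda')
model = model.to(dtype=torch.float16).eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
out = model.generate(input_ids=input_ids, max_length=30, fused_ft_kernel=True,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.decode(out.sequences[0]))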