gaoqiong / flash-attention

Commit 63670fd8, authored Dec 27, 2022 by Tri Dao
Implement generation for GPT
Parent: 9d797d88

Showing 5 changed files with 242 additions and 39 deletions:
flash_attn/models/gpt.py              +13  -5
flash_attn/modules/mha.py             +82  -24
flash_attn/utils/generation.py        +64  -0
tests/models/test_gpt.py              +0   -10
tests/models/test_gpt_generation.py   +83  -0
flash_attn/models/gpt.py  (+13 -5)

```diff
@@ -20,6 +20,7 @@ from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_sequence_parallel_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import GenerationMixin
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -61,7 +62,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                      if process_group is None else {})
     parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
-                        softmax_scale=softmax_scale, causal=True,
+                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         use_flash_attn=use_flash_attn, **serial_kwargs, **parallel_kwargs,
                         **factory_kwargs)
@@ -220,7 +221,7 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)

-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
@@ -238,12 +239,14 @@ class GPTModel(GPTPreTrainedModel):
                                          residual_in_fp32=True)
         mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        if inference_params is not None:
+            mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
             hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
         return hidden_states


-class GPTLMHeadModel(GPTPreTrainedModel):
+class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):

     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -267,8 +270,13 @@ class GPTLMHeadModel(GPTPreTrainedModel):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

-    def forward(self, input_ids, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids=position_ids)
+    def forward(self, input_ids, position_ids=None, inference_params=None):
+        """
+        inference_params: for generation. Adapted from Megatron-LM (and Apex)
+        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        """
+        hidden_states = self.transformer(input_ids, position_ids=position_ids,
+                                         inference_params=inference_params)
         lm_logits = self.lm_head(hidden_states)
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
```
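For orientation (not part of the commit), here is a minimal sketch of how the new `inference_params` argument threads through `GPTLMHeadModel.forward`: the prompt is run once with `sequence_len_offset == 0` to fill the per-layer KV caches, then each subsequent call feeds only the newly sampled token with the offset advanced. The helper name `two_step_decode` is hypothetical; the calls mirror `greedy_decode` in the new `flash_attn/utils/generation.py` below.

```python
# Sketch only (hypothetical helper): illustrates the inference_params flow added in this commit.
# Assumes a GPTLMHeadModel `model` on CUDA and a tokenized prompt `input_ids` of shape (batch, seqlen).
import torch
from flash_attn.utils.generation import InferenceParams

@torch.inference_mode()
def two_step_decode(model, input_ids, max_length):
    batch_size, seqlen = input_ids.shape
    params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    # Prompt pass: sequence_len_offset == 0, so attention stays causal and the KV caches get filled.
    logits = model(input_ids, inference_params=params).logits[:, -1]
    next_token = logits.argmax(dim=-1)
    # Single-token pass: feed only the new token; position_ids and sequence_len_offset tell the
    # rotary embedding and the KV cache where this token sits in the sequence.
    params.sequence_len_offset = seqlen
    position_ids = torch.full((batch_size, 1), params.sequence_len_offset,
                              dtype=torch.long, device=input_ids.device)
    logits = model(next_token[:, None], position_ids=position_ids,
                   inference_params=params).logits[:, -1]
    return next_token, logits.argmax(dim=-1)
```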
flash_attn/modules/mha.py  (+82 -24)

```diff
@@ -53,7 +53,7 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
+    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -61,6 +61,7 @@ class FlashSelfAttention(nn.Module):
                 If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                 If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                 (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into qkv.
             max_seqlen: int. Maximum sequence length in the batch.
@@ -71,6 +72,7 @@ class FlashSelfAttention(nn.Module):
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -78,13 +80,13 @@ class FlashSelfAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
                 qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
             if self.triton and (self.dropout_p == 0 or not self.training):
-                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+                output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                 max_seqlen = seqlen
@@ -92,7 +94,7 @@ class FlashSelfAttention(nn.Module):
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
                     qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
@@ -120,12 +122,14 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
+    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
+                cu_seqlens_k=None, max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
             max_seqlen: int. Maximum sequence length in the batch of q.
@@ -135,6 +139,7 @@ class FlashCrossAttention(nn.Module):
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
+        causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
@@ -147,14 +152,14 @@ class FlashCrossAttention(nn.Module):
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
                 self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
+                softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
             if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
-                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+                output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
                 kv = rearrange(kv, 'b s ... -> (b s) ...')
@@ -165,7 +170,7 @@ class FlashCrossAttention(nn.Module):
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                     self.dropout_p if self.training else 0.0,
-                    softmax_scale=self.softmax_scale, causal=self.causal
+                    softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
@@ -187,15 +192,17 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, qkv, key_padding_mask=None):
+    def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        causal = self.causal if causal is None else causal
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
@@ -205,7 +212,7 @@ class SelfAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
@@ -233,16 +240,18 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, q, kv, key_padding_mask=None):
+    def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
         k, v = kv.unbind(dim=2)
@@ -254,7 +263,7 @@ class CrossAttention(nn.Module):
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
-        if self.causal:
+        if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
@@ -280,7 +289,7 @@ class MHA(nn.Module):
     """

     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
                  rotary_emb_scale_base=0, fused_bias_fc=False, use_flash_attn=False,
                  return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
@@ -294,6 +303,7 @@ class MHA(nn.Module):
         self.embed_dim = embed_dim
         self.cross_attn = cross_attn
         self.causal = causal
+        self.layer_idx = layer_idx
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
@@ -315,6 +325,8 @@ class MHA(nn.Module):
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (LinearResidual if not fused_bias_fc
                             else partial(FusedDense, return_residual=True))
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
                 self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
@@ -323,7 +335,6 @@ class MHA(nn.Module):
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
-            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
@@ -335,14 +346,41 @@ class MHA(nn.Module):
                                            groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
                                            groups=2 * embed_dim)
-            inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
+        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                                     attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, 1, nheads, head_dim)
+        """
+        assert not self.dwconv, 'Generation does not support dwconv yet'
+        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        # Pre-allocate memory for key-values for inference.
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            inference_kv_cache = torch.empty(
+                inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+                self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+        else:
+            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+        # Adjust key and value for inference
+        batch_start = inference_params.batch_size_offset
+        batch_end = batch_start + kv.shape[0]
+        assert batch_end <= inference_kv_cache.shape[0]
+        sequence_start = inference_params.sequence_len_offset
+        sequence_end = sequence_start + kv.shape[1]
+        assert sequence_end <= inference_kv_cache.shape[1]
+        # Copy key and values.
+        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        return kv
+
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                **kwargs):
+                inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -355,6 +393,8 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            inference_params: for generation. Adapted from Megatron-LM (and Apex)
+            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
         if cu_seqlens is not None:
             assert max_seqlen is not None
@@ -366,6 +406,10 @@ class MHA(nn.Module):
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
+        if inference_params is not None:
+            assert key_padding_mask is None
+            assert cu_seqlens is None and max_seqlen is None
+            assert not self.dwconv
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
@@ -378,12 +422,22 @@ class MHA(nn.Module):
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            if not self.checkpointing:
-                context = self.inner_attn(qkv, **kwargs)
+            if inference_params is None:
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv)
+                if not self.checkpointing:
+                    context = self.inner_attn(qkv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                q = qkv[:, :, 0]
+                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                # If we're processing the prompt, causal=None (use self.causal).
+                # If we're decoding, then causal=False.
+                causal = None if inference_params.sequence_len_offset == 0 else False
+                context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
                 q = self.Wq(x)
@@ -401,10 +455,14 @@ class MHA(nn.Module):
                               'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            if not self.checkpointing:
-                context = self.inner_attn(q, kv, **kwargs)
+            if inference_params is None:
+                if not self.checkpointing:
+                    context = self.inner_attn(q, kv, **kwargs)
+                else:
+                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                kv = self._update_kv_cache(kv, inference_params)
+                context = self.inner_cross_attn(q, kv, causal=False)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
```
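The `_update_kv_cache` method above is the heart of incremental decoding: each layer writes its new key/value slice into a pre-allocated buffer, and the single-token query then attends over everything cached so far, which is why the decoding step passes `causal=False` while the prompt step leaves `causal=None` (falling back to `self.causal`). Below is a standalone sketch of the same bookkeeping on plain tensors; the toy shapes and the helper name `append_and_read` are made up for illustration and are not part of the module's API.

```python
import torch

# Toy dimensions for illustration only.
batch, nheads, head_dim, max_seqlen = 2, 4, 8, 16

# Pre-allocated cache, analogous to one entry of key_value_memory_dict:
# (batch, max_seqlen, 2, nheads, head_dim), where the 2 holds key and value.
kv_cache = torch.empty(batch, max_seqlen, 2, nheads, head_dim)

def append_and_read(kv_new, seqlen_offset):
    """Write kv_new (batch, s_new, 2, nheads, head_dim) at seqlen_offset and
    return the cache contents up to and including the new tokens."""
    s_new = kv_new.shape[1]
    kv_cache[:, seqlen_offset:seqlen_offset + s_new] = kv_new
    return kv_cache[:, :seqlen_offset + s_new]

# Prompt step: 5 tokens go in at offset 0; attention over them stays causal.
kv_prompt = torch.randn(batch, 5, 2, nheads, head_dim)
kv_so_far = append_and_read(kv_prompt, seqlen_offset=0)   # (batch, 5, 2, nheads, head_dim)

# Decode step: 1 new token at offset 5; its query attends to all 6 cached
# positions, so no causal mask is needed (causal=False in the diff above).
kv_step = torch.randn(batch, 1, 2, nheads, head_dim)
kv_so_far = append_and_read(kv_step, seqlen_offset=5)      # (batch, 6, 2, nheads, head_dim)
```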
flash_attn/utils/generation.py  (new file, mode 100644, +64 -0)

```python
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from dataclasses import dataclass, field

import torch

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)


def greedy_decode(input_ids, model, max_length):
    """Greedy decoding. This is a very simple implementation.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
    scores = []
    with torch.inference_mode():
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = logits.argmax(dim=-1)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = logits.argmax(dim=-1)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    return GreedySearchDecoderOnlyOutput(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )


class GenerationMixin:

    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
        output = greedy_decode(input_ids, self, max_length)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
```
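A hedged usage sketch of the new `GenerationMixin.generate` entry point; the setup mirrors the test added below and assumes a CUDA GPU and that the `gpt2` weights can be downloaded.

```python
# Sketch: greedy generation with the KV-cache path added in this commit.
import torch
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained("gpt2", config).cuda().to(dtype=torch.float16)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()

# return_dict_in_generate=True also returns the per-step logits as `scores`.
out = model.generate(input_ids, max_length=30,
                     return_dict_in_generate=True, output_scores=True)
print(tokenizer.batch_decode(out.sequences))
```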
tests/models/test_gpt.py  (+0 -10)

```diff
@@ -23,16 +23,6 @@ def test_gpt2_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-def get_hf_models(model_name, config, dtype):
-    pretrained_state_dict = state_dict_from_pretrained(model_name)
-    model_hf = GPT2LMHeadModelHF(config)
-    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
-    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
-    model_hf.load_state_dict(pretrained_state_dict, strict=False)
-    model_hf.cuda().to(dtype=dtype)
-    return model_hf
-
-
 @pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_non_optimized(model_name):
```
tests/models/test_gpt_generation.py  (new file, mode 100644, +83 -0)

```python
import re

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import greedy_decode


# TODO: test with rotary embedding

@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode(model_name, optimized):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    dtype = torch.float16
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_dense_gelu_dense = True
        config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
    max_length = 30

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(input_ids=input_ids, max_length=max_length,
                         return_dict_in_generate=True, output_scores=True)
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
                                 return_dict_in_generate=True, output_scores=True)

    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                          rtol=rtol, atol=atol)
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)
    assert ((torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
            < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item())
```