gaoqiong / flash-attention / Commits / 7c219154

Commit 7c219154, authored Jan 15, 2023 by Tri Dao

[Gen] Make generation work with Tensor Parallel

Parent: d5098324

Showing 6 changed files with 374 additions and 146 deletions (+374 -146)
csrc/ft_attention/ft_attention.cpp      +6    -0
flash_attn/models/gpt.py                +33   -20
flash_attn/modules/mha.py               +87   -54
flash_attn/utils/generation.py          +118  -64
flash_attn/utils/pretrained.py          +5    -2
tests/models/test_gpt_generation.py     +125  -6
csrc/ft_attention/ft_attention.cpp

  #include <torch/extension.h>
  #include "ATen/cuda/CUDAContext.h"
+ #include <c10/cuda/CUDAGuard.h>
  #include "decoder_masked_multihead_attention.h"
...
@@ -138,6 +140,10 @@ torch::Tensor single_query_attention(const torch::Tensor q,
          TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
      }
+     // Otherwise the kernel will be launched from cuda:0 device
+     // Cast to char to avoid compiler warning about narrowing
+     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
      torch::Tensor out = torch::empty_like(q);
      DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
...
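The added CUDAGuard pins kernel launches and allocations to q's device rather than defaulting to cuda:0. A minimal Python-side illustration of the same idea (a sketch, not code from this commit; assumes a machine with at least two GPUs):

import torch

q = torch.randn(1, 16, 64, device="cuda:1", dtype=torch.float16)
# Analogous to at::cuda::CUDAGuard in the extension: make q.device the current
# CUDA device for the enclosed work instead of cuda:0.
with torch.cuda.device(q.device):
    out = torch.empty_like(q)  # allocated on cuda:1, not cuda:0
assert out.device == q.device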
flash_attn/models/gpt.py

...
@@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
  from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
  from flash_attn.modules.block import Block
  from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
- from flash_attn.utils.distributed import sync_shared_params
+ from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
  from flash_attn.utils.pretrained import state_dict_from_pretrained
  from flash_attn.utils.generation import GenerationMixin
...
@@ -146,17 +146,23 @@ class GPTPreTrainedModel(nn.Module):
          self.config = config

      @classmethod
-     def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
+     def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
+                         world_size=1, rank=0, **kwargs):
          """
          Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
          Download and cache the pre-trained model file if needed.
          """
          # Instantiate model.
-         model = cls(config, *args, device=device, **kwargs)
-         load_return = model.load_state_dict(
-             remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
-             strict=strict
-         )
+         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
+         # If we're going to shard the model, then don't load fp32 weights to GPU.
+         state_dict = remap_state_dict_gpt2(
+             state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
+                                        dtype=dtype),
+             config
+         )
+         if world_size > 1:
+             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
+             state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
+         load_return = model.load_state_dict(state_dict, strict=strict)
          logger.info(load_return)
          return model
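For reference, the tensor-parallel loading path added here is exercised by the new test in tests/models/test_gpt_generation.py. A condensed sketch of that usage (process-group setup via apex as in the test; world_size and checkpoint name are illustrative):

import torch
from apex.transformer import parallel_state
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

torch.distributed.init_process_group(backend="nccl", init_method="env://")
world_size = 2
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()

config = GPT2Config.from_pretrained("gpt2")
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False  # the test disables sequence parallelism for generation

# Each rank loads the checkpoint on CPU in fp16, shards it, then moves only its shard to its GPU.
model = GPTLMHeadModel.from_pretrained(
    "gpt2", config, device=f"cuda:{torch.distributed.get_rank()}", dtype=torch.float16,
    process_group=process_group, world_size=world_size, rank=rank,
)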
...
@@ -190,17 +196,16 @@ class GPTModel(GPTPreTrainedModel):
          self.process_group = process_group
          self.sequence_parallel = getattr(config, 'sequence_parallel', True)
          assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
-         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
-         if config.vocab_size % self.pad_vocab_size_multiple != 0:
-             config.vocab_size += (self.pad_vocab_size_multiple
-                                   - (config.vocab_size % self.pad_vocab_size_multiple))
+         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
+                       * pad_vocab_size_multiple)
          if process_group is None:
-             self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
-                                              config.max_position_embeddings, **factory_kwargs)
+             self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
+                                              config.max_position_embeddings, **factory_kwargs)
          else:
              self.embeddings = ParallelGPT2Embeddings(
-                 config.hidden_size, config.vocab_size, config.max_position_embeddings,
+                 config.hidden_size, vocab_size, config.max_position_embeddings,
                  process_group=process_group, sequence_parallel=self.sequence_parallel,
                  **factory_kwargs
              )
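The padded vocabulary size is now computed by rounding up rather than by mutating config.vocab_size in place. A quick worked check of that arithmetic (GPT-2's 50257-token vocabulary; the multiple of 16 corresponds to pad_vocab_size_multiple = 8 * world_size with world_size = 2, as in the new test):

import math

vocab_size = 50257
for pad_vocab_size_multiple in (1, 8, 16):
    padded = math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    print(pad_vocab_size_multiple, padded)
# 1 -> 50257, 8 -> 50264, 16 -> 50272; choosing the multiple as 8 * world_size guarantees
# the padded vocabulary divides evenly across the tensor-parallel ranks.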
...
@@ -248,8 +253,9 @@ class GPTModel(GPTPreTrainedModel):
          hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
          # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
          if not self.fused_dropout_add_ln:
-             residual = self.emb_drop(hidden_states).float()
+             residual = self.emb_drop(hidden_states)
              hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
+             residual = residual.float()
          else:
              hidden_states, residual = dropout_add_layer_norm(hidden_states, None,
                                                               self.ln_0.weight, self.ln_0.bias,
...
@@ -272,13 +278,16 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
          super().__init__(config)
          self.process_group = process_group
          self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
+         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
+                       * pad_vocab_size_multiple)
          if process_group is None:
-             self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, **factory_kwargs)
+             self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
          else:
              if ColumnParallelLinear is None:
                  raise ImportError('fused_dense_lib is not installed')
              self.lm_head = ColumnParallelLinear(
-                 config.n_embd, config.vocab_size, process_group, bias=False,
+                 config.n_embd, vocab_size, process_group, bias=False,
                  sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
              )
          # Initialize weights and apply final processing
...
@@ -299,6 +308,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
          hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                           inference_params=inference_params)
          lm_logits = self.lm_head(hidden_states)
+         # During inference, we want the full logit for sampling
+         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
+             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
+             lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
          CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
          return CausalLMOutput(logits=lm_logits)
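The all-gather above stacks the rank-local logit shards along the batch dimension, so the '(n b) s d -> b s (n d)' rearrange is what reassembles the full vocabulary for sampling. A small shape check of that pattern (sketch only; torch.cat stands in for all_gather_raw):

import torch
from einops import rearrange

n, b, s, d = 2, 3, 5, 4  # n = tensor-parallel ranks, d = vocab shard per rank
shards = [torch.full((b, s, d), float(r)) for r in range(n)]
gathered = torch.cat(shards, dim=0)                        # (n*b, s, d), like all_gather_raw
full = rearrange(gathered, '(n b) s d -> b s (n d)', b=b)  # (b, s, n*d): shards side by side
assert full.shape == (b, s, n * d)
assert torch.all(full[..., :d] == 0) and torch.all(full[..., d:] == 1)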
...
@@ -310,8 +323,10 @@ def remap_state_dict_gpt2(state_dict, config):
      state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
      word_embeddings = state_dict.pop('wte.weight')
      # It's possible that vocab_size is padded to be a multiple of 8, for example.
+     pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
      state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
-         word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
+         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
      )
      state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
...
@@ -365,10 +380,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
      """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
      with tensor parallel.
      """
-     vocab_size = config.vocab_size
-     if config.vocab_size % config.pad_vocab_size_multiple != 0:
-         vocab_size += (config.pad_vocab_size_multiple
-                        - (config.vocab_size % config.pad_vocab_size_multiple))
+     pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
      assert vocab_size % world_size == 0
      assert config.hidden_size % world_size == 0
      inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
...
flash_attn/modules/mha.py

...
@@ -289,6 +289,60 @@ class LinearResidual(nn.Linear):
          return super().forward(input), input


+ def _update_kv_cache(kv, inference_params, layer_idx):
+     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+     """
+     # Pre-allocate memory for key-values for inference.
+     num_heads, head_dim = kv.shape[-2:]
+     if layer_idx not in inference_params.key_value_memory_dict:
+         kv_cache = torch.empty(
+             inference_params.max_batch_size, inference_params.max_sequence_len, 2,
+             num_heads, head_dim, dtype=kv.dtype, device=kv.device
+         )
+         inference_params.key_value_memory_dict[layer_idx] = kv_cache
+     else:
+         if not inference_params.fused_ft_kernel:
+             kv_cache = inference_params.key_value_memory_dict[layer_idx]
+         else:
+             # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
+             # where packsize = 4 if fp32, 8 if fp16 or bf16.
+             # v_cache has shape (b, h, s, headdim)
+             k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
+             kv_cache = None
+     # Adjust key and value for inference
+     batch_start = inference_params.batch_size_offset
+     batch_end = batch_start + kv.shape[0]
+     sequence_start = inference_params.sequence_len_offset
+     sequence_end = sequence_start + kv.shape[1]
+     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+     # Copy key and values.
+     if not inference_params.fused_ft_kernel:
+         assert kv_cache is not None
+         kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+         kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+         return kv
+     else:
+         assert inference_params.sequence_len_offset == 0
+         # FT kernel requires different layouts for the k_cache and v_cache.
+         assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+         packsize = 4 if kv.dtype == torch.float32 else 8
+         if kv_cache is not None:
+             kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+             k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                 packsize=packsize).contiguous()
+             v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+             inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
+         else:
+             k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+                 kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
+             )
+             v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
+                 kv[:, :, 1], 'b s h d -> b h s d'
+             )
+         return kv
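The comments above describe the layout the FasterTransformer kernel expects: keys are regrouped so each thread can load `packsize` consecutive elements. A small shape check of that rearrangement (a sketch; tiny tensor, fp16 packsize of 8):

import torch
from einops import rearrange

b, s, h, headdim, packsize = 2, 16, 12, 64, 8  # packsize = 8 for fp16/bf16, 4 for fp32
kv = torch.randn(b, s, 2, h, headdim, dtype=torch.float16)
k_cache = rearrange(kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize)
v_cache = rearrange(kv[:, :, 1], 'b s h d -> b h s d')
assert k_cache.shape == (b, h, headdim // packsize, s, packsize)
assert v_cache.shape == (b, h, s, headdim)
# Same elements as the (b, s, 2, h, headdim) cache, only reordered for coalesced loads.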
  class MHA(nn.Module):
      """Multi-head self-attention and cross-attention
      """
...
@@ -363,54 +417,7 @@ class MHA(nn.Module):
          """
          assert not self.dwconv, 'Generation does not support dwconv yet'
          assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
-         # Pre-allocate memory for key-values for inference.
-         if self.layer_idx not in inference_params.key_value_memory_dict:
-             kv_cache = torch.empty(
-                 inference_params.max_batch_size, inference_params.max_sequence_len, 2,
-                 self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
-             )
-             inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
-         else:
-             if not inference_params.fused_ft_kernel:
-                 kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
-             else:
-                 # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
-                 # where packsize = 4 if fp32, 8 if fp16 or bf16.
-                 # v_cache has shape (b, h, s, headdim)
-                 k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
-                 kv_cache = None
-         # Adjust key and value for inference
-         batch_start = inference_params.batch_size_offset
-         batch_end = batch_start + kv.shape[0]
-         sequence_start = inference_params.sequence_len_offset
-         sequence_end = sequence_start + kv.shape[1]
-         assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
-         assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
-         # Copy key and values.
-         if not inference_params.fused_ft_kernel:
-             assert kv_cache is not None
-             kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-             kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-             return kv
-         else:
-             assert inference_params.sequence_len_offset == 0
-             # FT kernel requires different layouts for the k_cache and v_cache.
-             assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
-             packsize = 4 if kv.dtype == torch.float32 else 8
-             if kv_cache is not None:
-                 kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-                 k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                     packsize=packsize).contiguous()
-                 v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
-                 inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
-             else:
-                 k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-                     kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
-                 )
-                 v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
-                     kv[:, :, 1], 'b s h d -> b h s d'
-                 )
-             return kv
+         return _update_kv_cache(kv, inference_params, self.layer_idx)
      def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                  inference_params=None, **kwargs):
...
@@ -473,6 +480,7 @@ class MHA(nn.Module):
                      causal = None if inference_params.sequence_len_offset == 0 else False
                      context = self.inner_cross_attn(q, kv, causal=causal)
                  else:
+                     assert inference_params.fused_ft_kernel
                      assert ft_attention is not None
                      context = ft_attention.single_query_attention(
                          *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
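The causal flag flip in this hunk (causal=None for the prompt pass, causal=False when decoding a single token) works because the newest query position is allowed to attend to every cached key, so masked and unmasked attention agree for that row. A toy check of that equivalence (a sketch with plain PyTorch ops, not the repo's kernels):

import torch

torch.manual_seed(0)
d = 8
keys = torch.randn(5, d)      # all cached keys, including the newest position
q = torch.randn(1, d)         # the single new query at the last position
scores = q @ keys.T / d**0.5
attn_unmasked = torch.softmax(scores, dim=-1)
mask = torch.ones(1, 5, dtype=torch.bool)  # causal-mask row for the last position: all True
attn_causal = torch.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1)
assert torch.allclose(attn_unmasked, attn_causal)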
...
@@ -541,13 +549,16 @@ class ParallelMHA(nn.Module):
          self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
                                           sequence_parallel=sequence_parallel, **factory_kwargs)
          inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
          self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                           attention_dropout=dropout)
+         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                                      attention_dropout=dropout)
          # output projection always have the bias (for now)
          self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
                                            sequence_parallel=sequence_parallel, **factory_kwargs)

-     def forward(self, x, seqlen=None, **kwargs):
+     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
          """
          Arguments:
              x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
...
@@ -561,12 +572,34 @@ class ParallelMHA(nn.Module):
          else:
              qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
                              d=self.head_dim)
-         if self.rotary_emb_dim > 0:
-             qkv = self.rotary_emb(qkv)
-         if not self.checkpointing:
-             context = self.inner_attn(qkv, **kwargs)
-         else:
-             context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+         if inference_params is None:
+             if self.rotary_emb_dim > 0:
+                 qkv = self.rotary_emb(qkv)
+             if not self.checkpointing:
+                 context = self.inner_attn(qkv, **kwargs)
+             else:
+                 context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+         else:
+             if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
+                 if self.rotary_emb_dim > 0:
+                     qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                 q = qkv[:, :, 0]
+                 assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+                 kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+                 # If we're processing the prompt, causal=None (use self.causal).
+                 # If we're decoding, then causal=False.
+                 causal = None if inference_params.sequence_len_offset == 0 else False
+                 context = self.inner_cross_attn(q, kv, causal=causal)
+             else:
+                 assert inference_params.fused_ft_kernel
+                 assert ft_attention is not None
+                 context = ft_attention.single_query_attention(
+                     *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
+                     *inference_params.key_value_memory_dict[self.layer_idx],
+                     inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                     self.rotary_emb_dim
+                 )
+                 context = rearrange(context, 'b h d -> b 1 h d')
          if seqlen is None:
              context = rearrange(context, 'b s h d -> b s (h d)')
          else:
...
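During incremental decoding the rotary embedding is applied with seqlen_offset set to the number of tokens already in the cache, so the single new token is rotated by its absolute position rather than position zero. A minimal sketch of why the offset matters (standalone; not the repo's RotaryEmbedding class):

import torch

def rotary_angles(positions, rotary_dim=64, base=10000.0):
    # Standard rotary frequencies; the angle depends on the absolute token position.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    return torch.outer(positions.to(torch.float32), inv_freq)

prompt_len = 8
prompt = rotary_angles(torch.arange(prompt_len))    # positions 0..7 during the prompt pass
step = rotary_angles(torch.tensor([prompt_len]))    # decode step: position 8, i.e. the offset
assert not torch.allclose(step, rotary_angles(torch.tensor([0])))  # offset 0 would be wrong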
flash_attn/utils/generation.py

- # Copyright (c) 2022, Tri Dao.
+ # Copyright (c) 2023, Tri Dao.
  # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
- from typing import Optional
+ from typing import Optional, Union, Sequence, Callable
+ import gc
  import time
  from dataclasses import dataclass, field
...
@@ -70,7 +71,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
  def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-            fused_ft_kernel=False, cg=False, timing=False):
+            vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
      """Decoding, either greedy or with top-k or top-p sampling.
      If top-k = 0, don't limit the number of candidates (pure sampling).
      Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
...
@@ -85,18 +86,30 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
          scores: tuples of (batch, vocab_size)
      """
      batch_size, seqlen_og = input_ids.shape
-     inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
-                                        fused_ft_kernel=fused_ft_kernel)
+     if cg:
+         assert fused_ft_kernel
+         if not hasattr(model, '_decoding_cache'):
+             model._decoding_cache = None
+         model._decoding_cache = update_graph_cache(model, model._decoding_cache, batch_size,
+                                                    seqlen_og, max_length,
+                                                    tensor_parallel=tensor_parallel)
+         inference_params = model._decoding_cache.inference_params
+         inference_params.max_sequence_len = max_length
+         inference_params.max_batch_size = batch_size
+         inference_params.sequence_len_offset = 0
+     else:
+         inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
+                                            fused_ft_kernel=fused_ft_kernel)
      scores = []
      with torch.inference_mode():
          logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+         if vocab_size is not None:
+             logits = logits[..., :vocab_size]
          scores.append(logits)
          next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
          sequences = [next_token]
          inference_params.sequence_len_offset = seqlen_og
-         if cg:
-             assert fused_ft_kernel
-             run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
          if timing:
              start = time.time()
          while True:
...
@@ -106,8 +119,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                  logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                                 inference_params=inference_params).logits[:, -1]
              else:
-                 logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
-                              inference_params.sequence_len_offset)
+                 logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
+                                                    inference_params.sequence_len_offset)
+             if vocab_size is not None:
+                 logits = logits[..., :vocab_size]
              scores.append(logits)
              next_token = sample(logits, top_k=top_k, temperature=temperature)
              sequences.append(next_token)
...
@@ -115,6 +130,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
              if inference_params.sequence_len_offset >= max_length - 1:
                  break
          if timing:
+             torch.cuda.synchronize()
              print(f'Decoding time: {time.time() - start}')
      output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
      return output_cls(
...
@@ -134,8 +150,18 @@ class GenerationMixin:
          return output if return_dict_in_generate else output.sequences


- CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
- CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
+ def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
+                       device, dtype=torch.float16):
+     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
+     packsize = 4 if dtype == torch.float32 else 8
+     assert headdim % packsize == 0
+     k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
+     v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
+     if isinstance(layers, int):
+         layers = range(layers)
+     return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
+                 torch.empty(v_cache_shape, device=device, dtype=dtype))
+             for i in layers}


  def seqlen_to_seqlen_type(seqlen: int) -> int:
...
@@ -152,63 +178,91 @@ def seqlen_type_to_seqlen(seqlen_type: int) -> int:
      return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
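allocate_kv_cache pre-allocates the per-layer FT-layout caches that the CUDA graphs below will keep reusing. A rough memory check of those shapes (a sketch with GPT-2-small-like numbers, not code from the commit):

import torch

B, H, S, headdim = 1, 12, 1024, 64  # batch, heads per rank, max seqlen, head dim
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8
k_elems = B * H * (headdim // packsize) * S * packsize  # k_cache layout
v_elems = B * H * S * headdim                           # v_cache layout
assert k_elems == v_elems  # same storage, just reordered for the kernel
bytes_per_layer = (k_elems + v_elems) * torch.finfo(dtype).bits // 8
print(f"{bytes_per_layer / 2**20:.1f} MiB per layer, {12 * bytes_per_layer / 2**20:.1f} MiB total")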
- def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
-     """Build a cache of cuda graphs for decoding.
-     Arguments:
-         model: a GPTLMHeadModel
-         batch_size: int
-         seqlen_og: int. Length of the prompt.
-         max_length: int
-     TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
-     has to own the k_cache and v_cache?
-     Here we assume that the model already has inference_params from the prompt processing.
-     """
-     assert max_length > seqlen_og
-     cg_cache: dict[CgKey, CgVal] = {}
-     device = next(iter(model.parameters())).device
-     sequence_length_offset_og = inference_params.sequence_len_offset
-     inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
-                                                      device=device)
-     memory_pool = None
-     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
-         seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
-         input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-         position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-         inference_params.lengths_per_sample[:] = seqlen
-         inference_params.sequence_len_offset = seqlen
-         g = torch.cuda.CUDAGraph()
-         # Warmup before capture
-         s = torch.cuda.Stream()
-         s.wait_stream(torch.cuda.current_stream())
-         with torch.cuda.stream(s):
-             for _ in range(2):
-                 logits = model(input_ids, position_ids=position_ids,
-                                inference_params=inference_params).logits[:, -1]
-         torch.cuda.current_stream().wait_stream(s)
-         # Captures the graph
-         # To allow capture, automatically sets a side stream as the current stream in the context
-         with torch.cuda.graph(g, pool=memory_pool):
-             logits = model(input_ids, position_ids=position_ids,
-                            inference_params=inference_params).logits[:, -1]
-         if memory_pool is None:
-             s.synchronize()
-             memory_pool = g.pool()
-         cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
-             g, input_ids, position_ids, inference_params.lengths_per_sample, logits
-         )
-
-     def run(new_input_ids, new_position_ids, seqlen):
-         cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
-         inference_params.lengths_per_sample = cg_val.lengths
-         inference_params.lengths_per_sample[:] = seqlen
-         cg_val.input_ids.copy_(new_input_ids)
-         cg_val.position_ids.copy_(new_position_ids)
-         cg_val.graph.replay()
-         output = cg_val.logits
-         return output.clone() if copy_output else output
-
-     inference_params.sequence_len_offset = sequence_length_offset_og
-     return run, cg_cache
+ @dataclass
+ class DecodingCGCache:
+     max_batch_size: int = 0
+     max_seqlen: int = 0
+     device = None
+     dtype = None
+     callables: dict = field(default_factory=dict)
+     mempool = None
+     inference_params: Optional[InferenceParams] = None
+     run: Optional[Callable] = None
+
+
+ @torch.inference_mode()
+ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
+                        dtype=None):
+     if cache is None:
+         cache = DecodingCGCache()
+     param_example = next(iter(model.parameters()))
+     device = param_example.device
+     if dtype is None:
+         dtype = param_example.dtype
+     if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
+             or max_seqlen > cache.max_seqlen):  # Invalidate the cache
+         cache.callables = {}
+         cache.mempool = None
+         cache.inference_params = None
+         gc.collect()
+         cache.device, cache.dtype = device, dtype
+         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
+         headdim = getattr(model.config, 'head_dim',
+                           model.config.hidden_size // model.config.num_attention_heads)
+         kv_cache = allocate_kv_cache(
+             batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
+             model.config.num_hidden_layers, device, dtype
+         )
+         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
+         cache.inference_params = InferenceParams(
+             max_sequence_len=max_seqlen, max_batch_size=batch_size,
+             sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache,
+             fused_ft_kernel=True, lengths_per_sample=lengths_per_sample
+         )
+         cache.mempool = torch.cuda.graphs.graph_pool_handle()
+     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
+         if s_type not in cache.callables:
+             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+             cache.callables[s_type] = capture_graph(
+                 model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+             )
+
+     def dispatch(input_ids, position_ids, seqlen):
+         return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
+
+     cache.run = dispatch
+     cache.inference_params.sequence_length_offset = 0  # Reset so it's not confusing
+     return cache
+
+
+ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+     assert max_seqlen >= seqlen_og
+     device = next(iter(model.parameters())).device
+     sequence_length_offset_og = inference_params.sequence_len_offset
+     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+     inference_params.lengths_per_sample[:] = seqlen_og
+
+     # Warmup before capture
+     s = torch.cuda.Stream()
+     s.wait_stream(torch.cuda.current_stream())
+     with torch.cuda.stream(s):
+         for _ in range(2):
+             logits = model(input_ids, position_ids=position_ids,
+                            inference_params=inference_params).logits[:, -1]
+     torch.cuda.current_stream().wait_stream(s)
+
+     # Captures the graph
+     # To allow capture, automatically sets a side stream as the current stream in the context
+     graph = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(graph, pool=mempool):
+         logits = model(input_ids, position_ids=position_ids,
+                        inference_params=inference_params).logits[:, -1]
+
+     def run(new_input_ids, new_position_ids, seqlen):
+         inference_params.lengths_per_sample[:] = seqlen
+         input_ids.copy_(new_input_ids)
+         position_ids.copy_(new_position_ids)
+         graph.replay()
+         return logits
+
+     inference_params.sequence_len_offset = sequence_length_offset_og
+     return run
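The graph cache above follows PyTorch's standard capture-and-replay recipe: warm up on a side stream, capture one decode step into a CUDAGraph, then replay it after copying new inputs into the captured static tensors. A stripped-down standalone version of that pattern (a sketch, unrelated to the model code; requires a CUDA GPU):

import torch

device = "cuda"
static_x = torch.zeros(8, 16, device=device)
weight = torch.randn(16, 16, device=device)

# Warmup on a side stream so capture starts from a quiet state.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_y = static_x @ weight
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight  # static_y becomes the captured output buffer

def run(new_x):
    static_x.copy_(new_x)  # copy new inputs into the captured buffers
    graph.replay()         # re-launch the recorded kernels
    return static_y        # same storage every call, like the captured logits above

out = run(torch.randn(8, 16, device=device))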
flash_attn/utils/pretrained.py

...
@@ -4,5 +4,8 @@ from transformers.utils import WEIGHTS_NAME
  from transformers.utils.hub import cached_file


- def state_dict_from_pretrained(model_name, device=None):
-     return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+ def state_dict_from_pretrained(model_name, device=None, dtype=None):
+     state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+     if dtype is not None:
+         state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+     return state_dict
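Hypothetical usage of the new dtype argument (a sketch; "gpt2" is just an example checkpoint name):

import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Keep weights on CPU and cast to fp16 up front, so the fp32 copy never has to fit on the GPU.
sd = state_dict_from_pretrained("gpt2", device=None, dtype=torch.float16)
assert all(v.dtype == torch.float16 for v in sd.values() if v.is_floating_point())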
tests/models/test_gpt_generation.py

+ import os
  import re

  import torch
...
@@ -11,12 +12,13 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
  from flash_attn.models.gpt import GPTLMHeadModel
  from flash_attn.models.gpt import remap_state_dict_gpt2
  from flash_attn.utils.pretrained import state_dict_from_pretrained
+ from flash_attn.utils.distributed import all_gather_raw


  @pytest.mark.parametrize('fused_ft_kernel', [False, True])
  # @pytest.mark.parametrize('fused_ft_kernel', [True])
  @pytest.mark.parametrize('optimized', [False, True])
- # @pytest.mark.parametrize('optimized', [False])
+ # @pytest.mark.parametrize('optimized', [True])
  @pytest.mark.parametrize('rotary', [False, True])
  # @pytest.mark.parametrize('rotary', [False])
  @pytest.mark.parametrize('model_name', ["gpt2"])
...
@@ -40,19 +42,20 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
      # if not rotary, we load the weight from HF but ignore the position embeddings.
      # The model would be nonsense but it doesn't matter for the test.
-     model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
-     model = model.to(dtype=dtype)
+     model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
+                                            dtype=dtype)
      model.eval()

      if not rotary:
-         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
-         model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
+         model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
          model_ref.eval()
          model_hf.eval()

      torch.manual_seed(0)
      tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
+     input_ids = tokenizer("Hello, my dog is cute and ",
+                           return_tensors="pt").input_ids.to(device=device)
      max_length = 30
      # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
      # max_length = input_ids.shape[1] + 40
...
@@ -100,3 +103,119 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
          assert torch.all(out.sequences == out_hf.sequences)
          assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
              torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+ # Run test with:
+ # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
+ # @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
+ @pytest.mark.parametrize('world_size', [2])
+ # @pytest.mark.parametrize('fused_ft_kernel', [False, True])
+ @pytest.mark.parametrize('fused_ft_kernel', [True])
+ # @pytest.mark.parametrize('rotary', [False, True])
+ @pytest.mark.parametrize('rotary', [False])
+ @pytest.mark.parametrize('model_name', ["gpt2"])
+ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
+     """Check that our implementation of GPT2 generation matches the HF implementation:
+     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+     the HF scores in fp32.
+     """
+     dtype = torch.float16
+     rtol, atol = 3e-3, 3e-1
+     config = GPT2Config.from_pretrained(model_name)
+     if rotary:
+         config.n_positions = 0
+         config.rotary_emb_dim = 64
+     config.use_flash_attn = True
+     config.fused_bias_fc = True
+     config.fused_dense_gelu_dense = True
+     config.fused_dropout_add_ln = True
+     config.pad_vocab_size_multiple = 8 * world_size
+     config.sequence_parallel = False  # Need to set this to False for generation
+
+     os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+     if not torch.distributed.is_initialized():
+         torch.distributed.init_process_group(backend='nccl', init_method='env://')
+     device = f'cuda:{torch.distributed.get_rank()}'
+     assert world_size <= torch.distributed.get_world_size()
+     # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
+     # GPU0 and GPU1 and things would hang
+     torch.cuda.set_device(device)
+     from apex.transformer import parallel_state
+     parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+     rank = parallel_state.get_tensor_model_parallel_rank()
+     process_group = parallel_state.get_tensor_model_parallel_group()
+
+     # if not rotary, we load the weight from HF but ignore the position embeddings.
+     # The model would be nonsense but it doesn't matter for the test.
+     model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
+                                            dtype=dtype, process_group=process_group,
+                                            world_size=world_size, rank=rank)
+     model.eval()
+
+     if not rotary:
+         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
+         model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
+         model_ref.eval()
+         model_hf.eval()
+
+     torch.manual_seed(0)
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     input_ids = tokenizer("Hello, my dog is cute and ",
+                           return_tensors="pt").input_ids.to(device=device)
+     max_length = 30
+     # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
+     # max_length = input_ids.shape[1] + 40
+
+     # Slow generation for reference
+     sequences = []
+     scores = []
+     cur_input_ids = input_ids
+     with torch.inference_mode():
+         logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
+         logits = rearrange(logits, '(n b) d -> b (n d)',
+                            b=input_ids.shape[0])[..., :config.vocab_size]
+         scores.append(logits)
+         sequences.append(scores[-1].argmax(dim=-1))
+         for _ in range(input_ids.shape[1] + 1, max_length):
+             cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
+             logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
+             logits = rearrange(logits, '(n b) d -> b (n d)',
+                                b=input_ids.shape[0])[..., :config.vocab_size]
+             scores.append(logits)
+             sequences.append(scores[-1].argmax(dim=-1))
+     sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
+     scores = tuple(scores)
+     print(sequences)
+
+     out = model.generate(input_ids=input_ids, max_length=max_length,
+                          tensor_parallel=world_size, vocab_size=config.vocab_size,
+                          fused_ft_kernel=fused_ft_kernel,
+                          return_dict_in_generate=True, output_scores=True, timing=True)
+     print(out.sequences)
+     if fused_ft_kernel:
+         out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                                 tensor_parallel=world_size, vocab_size=config.vocab_size,
+                                 fused_ft_kernel=fused_ft_kernel, cg=True,
+                                 return_dict_in_generate=True, output_scores=True, timing=True)
+         print(out_cg.sequences)
+
+     if not rotary:
+         out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                                    return_dict_in_generate=True, output_scores=True)
+         out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                      return_dict_in_generate=True, output_scores=True)
+         print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+         print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+         print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+         print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+
+     assert torch.all(out.sequences == sequences)
+     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
+                           rtol=rtol, atol=atol)
+     if not rotary:
+         assert torch.all(out.sequences == out_ref.sequences)
+         assert torch.all(out.sequences == out_hf.sequences)
+         assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
+             torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()