gaoqiong / flash-attention · Commits

Commit f1e01c27, authored Jan 15, 2023 by Tri Dao

[Gen] Pass qkv_stride to ft_attention kernel for batched generation

Parent: 7c219154
Showing 2 changed files with 15 additions and 21 deletions:

  csrc/ft_attention/ft_attention.cpp    +14  -20
  tests/models/test_gpt_generation.py    +1   -1
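For context, a minimal sketch of why a qkv batch stride is needed: during batched generation, the q, k, and v tensors for the current token are typically views into one packed qkv projection rather than standalone contiguous tensors, so their batch stride is larger than nheads * headdim. The packed [batch, 3, nheads, headdim] layout and the sizes below are assumptions chosen for illustration only; they are not taken from this commit.

// stride_sketch.cpp -- batch stride of q/k/v views into a packed qkv tensor
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t batch = 2, nheads = 16, headdim = 64;
    // Hypothetical packed projection produced once per generation step.
    auto qkv = torch::randn({batch, 3, nheads, headdim});
    auto q = qkv.select(1, 0);  // [batch, nheads, headdim], non-contiguous view
    auto k = qkv.select(1, 1);
    auto v = qkv.select(1, 2);
    // The batch stride of these views is 3 * nheads * headdim, not nheads * headdim,
    // which is exactly the layout the new checks in single_query_attention allow:
    // inner dims dense (stride(2) == 1, stride(1) == headdim) and equal batch strides.
    std::cout << "q.stride(0) = " << q.stride(0)
              << " (a contiguous copy would have " << nheads * headdim << ")\n";
    return 0;
}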
csrc/ft_attention/ft_attention.cpp
@@ -23,17 +23,6 @@
     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'");   \
 }
-// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                     \
-//     if (TYPE == at::ScalarType::Half) {                                       \
-//         using scalar_t = at::Half;                                            \
-//         __VA_ARGS__();                                                        \
-//     } else if (TYPE == at::ScalarType::Float) {                               \
-//         using scalar_t = float;                                               \
-//         __VA_ARGS__();                                                        \
-//     } else {                                                                  \
-//         AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'");  \
-//     }
 template<typename T>
 void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                 const cudaStream_t& stream);

@@ -66,6 +55,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const bool neox_rotary_style,
+                const int qkv_batch_stride,
                 T *q_ptr,
                 T *k_ptr,
                 T *v_ptr,

@@ -85,7 +75,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = 0;
+    params.stride = qkv_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;

@@ -98,8 +88,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.total_padding_tokens = nullptr;
     params.masked_tokens = nullptr;
     params.prefix_prompt_lengths = nullptr;
-    // params.max_prefix_prompt_length = memory_max_seqlen;  // TODO: waht should this be?
-    params.max_prefix_prompt_length = 0;  // TODO: waht should this be?
+    params.max_prefix_prompt_length = 0;
     params.relative_attention_bias = nullptr;
     params.relative_attention_bias_stride = 0;
     params.cross_attention_out = nullptr;

@@ -127,10 +116,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     CHECK_SHAPE(q, batch_size, nheads, headdim);
     CHECK_SHAPE(k, batch_size, nheads, headdim);
     CHECK_SHAPE(v, batch_size, nheads, headdim);
-    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
     CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
-    // TODO: avoid contiguous requirment by storing the stride
-    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
-    CHECK_CONTIGUOUS(v_cache);
+    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
+    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
+    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
+    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
+    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
+    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
+    CHECK_CONTIGUOUS(v_cache);
+    CHECK_CONTIGUOUS(k_cache);
     if (length_per_sample_.has_value()) {
         auto length_per_sample = length_per_sample_.value();

@@ -146,11 +140,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     torch::Tensor out = torch::empty_like(q);
-    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
+    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                    rotary_embedding_dim, neox_rotary_style,
+                   q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
...
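How params.stride is consumed inside the FasterTransformer masked_multihead_attention kernel is not shown in this diff. The sketch below is a hypothetical illustration of the usual strided-batch indexing pattern; the helper q_element, its parameters, and the zero-stride fallback are invented for the example and are not taken from the flash-attention or FasterTransformer sources.

// strided_index_sketch.cpp -- hypothetical strided lookup of q element (b, h, d)
#include <cstddef>

template <typename T>
const T& q_element(const T* q_ptr, std::size_t b, std::size_t h, std::size_t d,
                   std::size_t nheads, std::size_t headdim,
                   std::size_t qkv_batch_stride) {
    // With a packed qkv projection the batch stride is 3 * nheads * headdim;
    // for standalone contiguous q/k/v it degenerates to nheads * headdim.
    const std::size_t batch_stride =
        qkv_batch_stride != 0 ? qkv_batch_stride : nheads * headdim;
    return q_ptr[b * batch_stride + h * headdim + d];
}

In terms of this diff: the binding previously set params.stride = 0 and required CHECK_CONTIGUOUS on q, k, and v; it now forwards q.stride(0) as qkv_batch_stride and only asserts that the inner dimensions are dense and that q, k, and v share the same batch stride.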

tests/models/test_gpt_generation.py
@@ -57,7 +57,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     input_ids = tokenizer("Hello, my dog is cute and ",
                           return_tensors="pt").input_ids.to(device=device)
     max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
     # Slow generation for reference
...