gaoqiong / flash-attention · Commits

Commit a43fbbf1
Authored Apr 22, 2024 by Woosuk Kwon
Merge remote-tracking branch 'tri/main'
Parents: 498cd8c3, 85881f54
Showing 7 changed files with 110 additions and 63 deletions (+110 -63)
.github/workflows/publish.yml            +2   -2
csrc/flash_attn/flash_api.cpp            +12  -4
csrc/flash_attn/src/flash_fwd_kernel.h   +51  -54
setup.py                                 +5   -0
tests/test_rotary.py                     +37  -0
training/Dockerfile                      +2   -2
vllm_flash_attn/__init__.py              +1   -1
.github/workflows/publish.yml
@@ -44,7 +44,7 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207']
           cuda-version: ['11.8.0', '12.2.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -63,7 +63,7 @@ jobs:
             python-version: '3.7'
           - torch-version: '2.2.0'
             python-version: '3.7'
-          - torch-version: '2.3.0.dev20240105'
+          - torch-version: '2.3.0.dev20240207'
             python-version: '3.7'
           # Pytorch <= 2.0 only supports CUDA <= 11.8
           - torch-version: '1.12.1'
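The matrix above cross-multiplies OS, Python, Torch, and CUDA versions, then prunes unsupported combinations with exclude entries like the ones in this hunk. A small Python sketch, not part of the workflow, of how many wheel jobs such a matrix yields (the exclude list is abbreviated to the entries visible here):

from itertools import product

python_versions = ['3.7', '3.8', '3.9', '3.10', '3.11']
torch_versions = ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207']
cuda_versions = ['11.8.0', '12.2.2']

# Only the exclusions visible in this hunk; the real workflow lists more.
excludes = [
    {'python-version': '3.7', 'torch-version': '2.2.0'},
    {'python-version': '3.7', 'torch-version': '2.3.0.dev20240207'},
]

def excluded(job):
    # A job is dropped when it matches every key of some exclude entry,
    # which is how GitHub Actions' matrix.exclude behaves.
    return any(all(job[k] == v for k, v in ex.items()) for ex in excludes)

jobs = [
    {'python-version': py, 'torch-version': pt, 'cuda-version': cu}
    for py, pt, cu in product(python_versions, torch_versions, cuda_versions)
]
print(sum(not excluded(j) for j in jobs), "of", len(jobs), "matrix jobs remain")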
csrc/flash_attn/flash_api.cpp
@@ -205,7 +205,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     params.num_splits = num_splits;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
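For orientation: set_params_splitkv decides how many chunks to split the KV sequence into when the regular grid (batch x heads x M-blocks) is too small to occupy the GPU, and the new comment notes that the SM count is doubled because the kernel's 128-thread blocks allow two resident blocks per SM. A rough Python sketch of that kind of split heuristic (illustrative only, not the exact num_splits_heuristic implemented in C++):

import math

def pick_num_splits(batch_nheads_mblocks, num_block_slots, num_n_blocks, max_splits=128):
    """Illustrative stand-in for a split-KV heuristic: split until the grid keeps
    most of the GPU busy, but never into more pieces than there are KV blocks."""
    if batch_nheads_mblocks >= 0.8 * num_block_slots:
        return 1  # grid already (nearly) fills the GPU; splitting only adds reduction work
    for num_splits in range(1, min(max_splits, num_n_blocks, num_block_slots) + 1):
        blocks = batch_nheads_mblocks * num_splits
        waves = math.ceil(blocks / num_block_slots)
        if blocks / (waves * num_block_slots) >= 0.85:  # >= 85% of block slots used
            return num_splits
    return min(max_splits, num_n_blocks)

# Decode-style example: batch 2, 8 KV heads, 1 M-block, a 108-SM GPU with
# two 128-thread blocks per SM (the "* 2" from the hunk above).
print(pick_num_splits(2 * 8 * 1, 108 * 2, num_n_blocks=64))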
@@ -295,8 +296,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
         // H/t Daniel Haziza
         const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+        const int ngroups = num_heads / num_heads_k;
         if (seqlenq_ngroups_swapped) {
-            const int ngroups = num_heads / num_heads_k;
             q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
             seqlen_q = ngroups;
             num_heads = num_heads_k;
@@ -323,7 +324,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
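The two mha_fwd hunks above implement the seqlenq_ngroups_swapped trick: for single-token decoding with grouped-query attention, the ngroups query heads that share a KV head are moved into the (length-1) sequence dimension, the kernel runs with seqlen_q = ngroups and num_heads = num_heads_k, and a user-provided output buffer is viewed the same way before the shape check. A small PyTorch sketch of the same reshapes (tensor names are illustrative, not the extension's API):

import torch

batch_size, num_heads, num_heads_k, head_size = 2, 32, 8, 128
ngroups = num_heads // num_heads_k  # query heads per KV head (here 4)

# Decode-time query: one token, all query heads: (b, 1, nheads, d).
q = torch.randn(batch_size, 1, num_heads, head_size)

# Same as the hunk: view the head dim as (nheads_k, ngroups) and swap ngroups
# into the sequence position, giving (b, ngroups, nheads_k, d).
q_swapped = q.reshape(batch_size, num_heads_k, ngroups, head_size).transpose(1, 2)
assert q_swapped.shape == (batch_size, ngroups, num_heads_k, head_size)

# The user-provided output buffer gets the same treatment, mirroring
# out.reshape({b, nheads_k, ngroups, d}).transpose(1, 2) in the hunk above.
out = torch.empty_like(q)
out_swapped = out.reshape(batch_size, num_heads_k, ngroups, head_size).transpose(1, 2)
assert out_swapped.shape == q_swapped.shape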
@@ -494,8 +498,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
         // H/t Daniel Haziza
         const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+        const int ngroups = num_heads / num_heads_k;
         if (seqlenq_ngroups_swapped) {
-            const int ngroups = num_heads / num_heads_k;
             q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
             max_seqlen_q = ngroups;
             num_heads = num_heads_k;
@@ -550,6 +554,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -68,14 +68,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
     // Otherwise we might read OOB elements from gK and gV.
     if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
-        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                make_stride(params.o_row_stride, _1{}));
-        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                                make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+        Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+        Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+                                  make_shape(params.b, params.h, params.seqlen_q),
+                                  make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+        Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

         typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -108,25 +110,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
-    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
-        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
-    // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
-        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
-        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
         + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
-    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.k_row_stride, _1{}));
-    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
-                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                            make_stride(params.v_row_stride, _1{}));
+    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+                                          + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
+    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+                                          + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
     Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                             Shape<Int<kBlockM>, Int<kBlockN>>{},
                             make_stride(params.seqlen_k_rounded, _1{}));
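The kernel hunks replace hand-computed row offsets (row_offset_q, row_offset_k, row_offset_v, row_offset_o) with whole-gmem views (mQ, mK, mV, mO) that are sliced per head and tiled with local_tile. Both formulations address the same elements; a quick Python check of that equivalence for Q using plain index arithmetic (block sizes and strides are made-up values, the two formulas are the ones visible in the hunks):

# Old addressing: base + q_offset(batch) + m_block*kBlockM*row_stride + bidh*head_stride,
# then a (kBlockM, kHeadDim) tile with strides (row_stride, 1).
# New addressing: a (seqlen, heads, headdim) view with strides (row_stride, head_stride, 1),
# sliced at head bidh and tiled at block coordinate (m_block, 0).

kBlockM, kHeadDim = 128, 64
q_row_stride, q_head_stride = 2048, 64   # made-up strides for the illustration
m_block, bidh = 3, 5
q_batch_offset = 0                       # binfo.q_offset(...) contribution, fixed per block

def old_offset(row_in_tile, col):
    row_offset_q = q_batch_offset + m_block * kBlockM * q_row_stride + bidh * q_head_stride
    return row_offset_q + row_in_tile * q_row_stride + col

def new_offset(row_in_tile, col):
    # mQ(row, head, col) lives at row*q_row_stride + head*q_head_stride + col;
    # local_tile at (m_block, 0) shifts the row index by m_block*kBlockM.
    return (q_batch_offset + (m_block * kBlockM + row_in_tile) * q_row_stride
            + bidh * q_head_stride + col)

assert all(old_offset(r, c) == new_offset(r, c)
           for r in range(kBlockM) for c in range(kHeadDim))
print("old row_offset_q addressing == mQ/local_tile addressing")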
@@ -146,9 +150,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

     typename Kernel_traits::TiledMma tiled_mma;
@@ -241,7 +245,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -282,12 +286,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Advance gV
         if (masking_step > 0) {
-            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
             flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+                gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
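Because gK and gV now carry the KV block index as an explicit trailing mode (the nblocksN annotation above), the main loop selects a block by indexing tKgK(_, _, _, n_block) or tVgV(_, _, _, n_block) instead of repeatedly decrementing the tensor's data pointer by kBlockN * row_stride. A NumPy sketch of the same idea with illustrative shapes, not the kernel's:

import numpy as np

kBlockN, head_dim, n_blocks = 4, 8, 6
k = np.arange(n_blocks * kBlockN * head_dim).reshape(n_blocks * kBlockN, head_dim)

# Pointer-style: recompute a flat row offset for each block visited.
def block_by_offset(n_block):
    start = n_block * kBlockN
    return k[start:start + kBlockN]

# Mode-style: split the row axis into (n_blocks, kBlockN) once, then just index.
k_blocked = k.reshape(n_blocks, kBlockN, head_dim)

for n_block in reversed(range(n_blocks)):
    assert np.array_equal(block_by_offset(n_block), k_blocked[n_block])
print("indexing the block mode selects the same rows as advancing an offset")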
@@ -305,9 +308,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -355,9 +356,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
-        // Advance gV
-        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();

         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
@@ -368,9 +367,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -422,14 +419,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

-    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.o_row_stride, _1{}));
-    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                              Shape<Int<kBlockM>>{}, Stride<_1>{});
+    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+                                          + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+                              make_shape(params.b, params.h, params.seqlen_q),
+                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
+    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -556,8 +555,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -556,8 +555,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const
index_t
row_offset_q
=
binfo
.
q_offset
(
params
.
q_batch_stride
,
params
.
q_row_stride
,
bidb
)
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
// We move K and V to the last block.
// We move K and V to the last block.
const
int
bidb_cache
=
params
.
cache_batch_idx
==
nullptr
?
bidb
:
params
.
cache_batch_idx
[
bidb
];
const
int
bidb_cache
=
params
.
cache_batch_idx
==
nullptr
?
bidb
:
params
.
cache_batch_idx
[
bidb
];
const
int
*
block_table
=
params
.
block_table
==
nullptr
?
nullptr
:
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
*
block_table
=
params
.
block_table
==
nullptr
?
nullptr
:
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
@@ -573,9 +570,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
-    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
-                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.q_row_stride, _1{}));
+    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.k_row_stride, _1{}));
@@ -1051,8 +1050,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
-    // __syncthreads();
-    // if (cute::thread0()) { print(tOgOaccum); }
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
setup.py
@@ -200,6 +200,11 @@ if not SKIP_CUDA_BUILD:
                         # "--ptxas-options=-v",
                         # "--ptxas-options=-O2",
                         # "-lineinfo",
+                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
+                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
+                        # "-DFLASHATTENTION_DISABLE_ALIBI",
+                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
+                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                     ]
                     + generator_flag
                     + cc_flag
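The added lines are commented-out nvcc defines; uncommenting one compiles the corresponding feature (backward pass, dropout, ALiBi, uneven K, local attention) out of the extension, which shrinks the binary and shortens the build. A hypothetical sketch of where such flags sit in a torch CUDAExtension build (names and paths are illustrative, not this repo's actual setup.py):

from torch.utils.cpp_extension import CUDAExtension

ext = CUDAExtension(
    name="example_flash_ext",             # illustrative, not this package's module name
    sources=["csrc/example_kernels.cu"],  # illustrative source path
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": [
            "-O3",
            "--use_fast_math",
            # Uncomment to strip optional features at compile time, as in the hunk above:
            # "-DFLASHATTENTION_DISABLE_BACKWARD",
            # "-DFLASHATTENTION_DISABLE_DROPOUT",
            # "-DFLASHATTENTION_DISABLE_ALIBI",
        ],
    },
)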
tests/test_rotary.py
@@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
     assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+def test_compilation_count():
+    batch_size = 1
+    headdim = 128
+    device = "cuda"
+    dtype = torch.float16
+
+    torch.manual_seed(42)
+
+    from triton.runtime.jit import JITFunction
+
+    from flash_attn.ops.triton.rotary import rotary_kernel
+
+    compilation_count = 0
+
+    def count_compilations(*args, **kwargs):
+        nonlocal compilation_count
+        compilation_count += 1
+
+    old_cache_func = JITFunction.cache_hook
+    try:
+        rotary_kernel.cache.clear()
+        JITFunction.cache_hook = count_compilations
+
+        for seqlen in (128, 256):
+            for nheads in (4, 32):
+                x = torch.randn(
+                    batch_size, seqlen, nheads, headdim, dtype=dtype, device=device
+                )
+                x.requires_grad_()
+                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
+                out = apply_rotary_emb(x, cos, sin)
+                out.backward(torch.randn_like(out))
+
+        # Only two kernels are expected to be compiled:
+        # * for the forward pass (conjugate=False)
+        # * for the backward pass (conjugate=True)
+        assert compilation_count == 2
+    finally:
+        JITFunction.cache_hook = old_cache_func
training/Dockerfile
@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.5.6
+RUN pip install flash-attn==2.5.7

 # Install CUDA extensions for fused dense
-RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.6#subdirectory=csrc/fused_dense_lib
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
vllm_flash_attn/__init__.py
-__version__ = "2.5.6"
+__version__ = "2.5.7"

 from vllm_flash_attn.flash_attn_interface import (
     flash_attn_func,