Repository: gaoqiong/flash-attention
Commit d562aa63 (unverified), authored Jul 31, 2024 by Woosuk Kwon, committed via GitHub on Jul 31, 2024

Sync with FA v2.6.0 to support soft capping (#13)
Parent: 12375706

Showing 20 changed files with 161 additions and 59 deletions (+161 / −59). The commit changes 81 files in total; this page lists the first 20.
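For context on what this sync enables: soft capping squashes the raw attention logits through a scaled tanh before the softmax, so scores are smoothly bounded to (−softcap, softcap) instead of being hard-clipped. A minimal sketch of that transform follows; the helper name is hypothetical, but the arithmetic mirrors the attention_ref change in tests/test_flash_attn.py further down in this diff.

```python
import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # softcap == 0.0 means the feature is disabled, as in the kernels and tests.
    if softcap <= 0.0:
        return scores
    # Same three steps as the reference implementation in this diff:
    #   scores /= softcap; scores = scores.tanh(); scores *= softcap
    return torch.tanh(scores / softcap) * softcap

# Logits far outside +/-50 are squashed to just under +/-50; small logits pass
# through almost unchanged, so gradients stay smooth (unlike hard clipping).
raw = torch.tensor([-200.0, -10.0, 0.0, 10.0, 200.0])
print(softcap_scores(raw, softcap=50.0))
```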
Changed files on this page:

  csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu          +1  −1
  csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu   +7  −0
  csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu          +1  −1
  csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu   +7  −0
  csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu          +1  −1
  csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu   +7  −0
  csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu          +1  −1
  csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu   +7  −0
  csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu          +1  −1
  csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu   +7  −0
  csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu          +1  −1
  csrc/flash_attn/src/generate_kernels.py                          +14 −9
  csrc/flash_attn/src/kernel_traits.h                              +1  −1
  csrc/flash_attn/src/mask.h                                       +3  −3
  csrc/flash_attn/src/rotary.h                                     +1  −1
  csrc/flash_attn/src/static_switch.h                              +10 −0
  csrc/flash_attn/src/utils.h                                      +1  −2
  setup.py                                                         +33 −16
  tests/test_flash_attn.py                                         +55 −19
  training/Dockerfile                                              +2  −2
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/generate_kernels.py
@@ -16,17 +16,18 @@ DTYPE_MAP = {
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"

 template<>
-void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
-    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """

 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

-template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
 """

 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

@@ -43,13 +44,14 @@ class Kernel:
     sm: int
     dtype: str
     head_dim: int
+    is_causal: bool
     direction: str

     @property
     def template(self) -> str:
         if self.direction == "fwd":
             return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(

@@ -57,18 +59,21 @@ class Kernel:
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )

     @property
     def filename(self) -> str:
-        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
+        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"


 def get_all_kernels() -> List[Kernel]:
-    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd", "fwd_split"]:
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
+    for direction in ["fwd", "fwd_split"]:
+        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
+    for direction in ["bwd"]:
+        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)


 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
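To see what the new IS_CAUSAL axis does to the generated sources, here is a small standalone sketch; the helper name is hypothetical, but the constants and filename scheme follow the updated script above. Every forward and split-KV kernel is now emitted twice, once plain and once with the causal_ infix, which is where the new *_causal_sm80.cu files in this commit come from.

```python
import itertools

# Constants as in the updated generate_kernels.py.
DTYPES = ["fp16", "bf16"]
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
IS_CAUSAL = ["false", "true"]
SM = [80]

def fwd_split_filenames():
    # Mirrors Kernel.filename for direction == "fwd_split".
    for dtype, head_dim, is_causal, sm in itertools.product(DTYPES, HEAD_DIMENSIONS, IS_CAUSAL, SM):
        causal_tag = "causal_" if is_causal == "true" else ""
        yield f"flash_fwd_split_hdim{head_dim}_{dtype}_{causal_tag}sm{sm}.cu"

names = list(fwd_split_filenames())
print(len(names))   # 32: 2 dtypes x 8 head dims x 2 causal variants
print(names[:2])    # ['flash_fwd_split_hdim32_fp16_sm80.cu', 'flash_fwd_split_hdim32_fp16_causal_sm80.cu']
```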
csrc/flash_attn/src/kernel_traits.h
@@ -4,7 +4,7 @@
 #pragma once

-#include "cute/algorithm/copy.hpp"
+#include "cute/tensor.hpp"

 #include "cutlass/cutlass.h"
 #include "cutlass/layout/layout.h"
csrc/flash_attn/src/mask.h
@@ -13,7 +13,7 @@ using namespace cute;
 template <typename Engine, typename Layout>
 __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                            const int col_idx_offset_ = 0) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;

@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride,
                                                  const int window_size_left, const int window_size_right) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;

@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 2, "Only support 2D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
csrc/flash_attn/src/rotary.h
@@ -4,7 +4,7 @@
 #pragma once

-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>

 #include "utils.h"
csrc/flash_attn/src/static_switch.h
@@ -56,6 +56,16 @@
 #define EVENK_SWITCH BOOL_SWITCH
 #endif

+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = false; \
+    return __VA_ARGS__();                     \
+  }()
+#else
+#define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
 #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
   [&] { \
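The new SOFTCAP_SWITCH follows the same pattern as the other feature switches in this header: with FLASHATTENTION_DISABLE_SOFTCAP defined, the softcap constant is pinned to false at compile time so no softcap kernel variants are instantiated; otherwise it falls back to BOOL_SWITCH. A rough Python analogue of that dispatch pattern, purely illustrative and not part of the codebase:

```python
# Illustrative analogue of SOFTCAP_SWITCH / BOOL_SWITCH; the flag stands in for
# the -DFLASHATTENTION_DISABLE_SOFTCAP compile definition toggled in setup.py.
DISABLE_SOFTCAP = False

def softcap_switch(cond: bool, fn):
    # When the feature is compiled out, the "constant" is forced to False no
    # matter what the runtime condition says; otherwise behave like BOOL_SWITCH.
    if DISABLE_SOFTCAP:
        return fn(False)
    return fn(cond)

launch = lambda is_softcap: f"launch kernel variant with Is_softcap={is_softcap}"
print(softcap_switch(cond=True, fn=launch))  # picks the softcap variant unless disabled
```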
csrc/flash_attn/src/utils.h
@@ -14,8 +14,7 @@
 #include <cuda_bf16.h>
 #endif

-#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
+#include <cute/tensor.hpp>

 #include <cutlass/array.h>
 #include <cutlass/cutlass.h>
setup.py
@@ -151,22 +151,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",

@@ -183,6 +183,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3", "-std=c++17"] + generator_flag,

@@ -203,6 +219,7 @@ if not SKIP_CUDA_BUILD:
                     # "-DFLASHATTENTION_DISABLE_BACKWARD",
                     "-DFLASHATTENTION_DISABLE_DROPOUT",
                     # "-DFLASHATTENTION_DISABLE_ALIBI",
+                    # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                     "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                 ]
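The source list in setup.py is maintained by hand, so the sixteen new *_causal_sm80.cu forward files and sixteen causal split-KV files generated above have to be spelled out. A quick sketch (hypothetical helper) that reproduces exactly the paths added in this hunk, useful for checking that none were missed:

```python
HEAD_DIMS = [32, 64, 96, 128, 160, 192, 224, 256]
DTYPES = ["fp16", "bf16"]

def causal_sources(kind: str):
    # kind is "fwd" or "fwd_split"; yields paths in the same order as the diff
    # (head dim outer, fp16 before bf16).
    for hdim in HEAD_DIMS:
        for dtype in DTYPES:
            yield f"csrc/flash_attn/src/flash_{kind}_hdim{hdim}_{dtype}_causal_sm80.cu"

fwd = list(causal_sources("fwd"))
split = list(causal_sources("fwd_split"))
print(len(fwd), len(split))  # 16 16
print(fwd[0])                # csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu
```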
tests/test_flash_attn.py
@@ -216,6 +216,7 @@ def attention_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):

@@ -233,7 +234,7 @@ def attention_ref(
         window_size: (int, int), left and right window size
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
             without changing the math. This is to estimate the numerical error from operation
             reordering.
     Output:

@@ -253,6 +254,10 @@ def attention_ref(
         scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
     else:
         scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if window_size[0] >= 0 or window_size[1] >= 0:

@@ -298,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):

@@ -313,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )

@@ -325,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):

@@ -340,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )

@@ -877,23 +886,29 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True

@@ -918,6 +933,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,

@@ -930,6 +946,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,

@@ -984,6 +1001,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,

@@ -995,6 +1013,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )

@@ -1010,6 +1029,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,

@@ -1022,6 +1042,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )

@@ -1036,7 +1057,7 @@ def test_flash_attn_output(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq,

@@ -1092,7 +1113,7 @@ def test_flash_attn_output(
     if not alibi:
         assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()

@@ -1133,24 +1154,31 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True

@@ -1198,6 +1226,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,

@@ -1229,6 +1258,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,

@@ -1288,6 +1318,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,

@@ -1299,6 +1330,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )

@@ -1314,6 +1346,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,

@@ -1326,6 +1359,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )

@@ -1339,7 +1373,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq_unpad,

@@ -1396,9 +1430,9 @@ def test_flash_attn_varlen_output(
     assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
     # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
     if not alibi:
-        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()

@@ -1917,9 +1951,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-        if new_kv
-        else (seqlen_k + 1),
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,

@@ -2455,12 +2491,12 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)

 @pytest.mark.parametrize("dtype", [torch.float16])
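Two patterns from the updated tests are worth calling out: the reference implementation now applies the same tanh soft cap before masking and softmax, and correctness is judged with a relative bound, i.e. the kernel's max error against the fp32 reference must stay within a small multiple (2x here, 3x for the varlen gradients) of the error of the same math done in half precision. A condensed, self-contained sketch of both; the function and the stand-in tensors are illustrative, not the real test harness:

```python
import math
import torch

def ref_scores(q, k, softcap=0.0):
    # fp32 reference scores, mirroring attention_ref in the diff above.
    d = q.shape[-1]
    scores = torch.einsum("bthd,bshd->bhts", q.float() / math.sqrt(d), k.float())
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    return scores

torch.manual_seed(0)
q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)

ref = ref_scores(q, k, softcap=50.0)                      # fp32 reference
baseline = ref_scores(q.half(), k.half(), softcap=50.0)   # same math from half-precision inputs
out = baseline.clone()                                    # stand-in for the kernel output

# Relative-tolerance pattern used throughout the tests: the kernel may be at
# most 2x as far from the fp32 reference as the half-precision baseline is.
assert (out - ref).abs().max().item() <= 2 * (baseline - ref).abs().max().item()
```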
training/Dockerfile
@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.5.7
+RUN pip install flash-attn==2.6.0

 # Install CUDA extensions for fused dense
-RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.6.0#subdirectory=csrc/fused_dense_lib