gaoqiong / flash-attention / Commits / d562aa63

Commit d562aa63 (unverified), authored Jul 31, 2024 by Woosuk Kwon, committed via GitHub on Jul 31, 2024

Sync with FA v2.6.0 to support soft capping (#13)

Parent: 12375706
Changes: 81 (this page shows 20 changed files with 161 additions and 59 deletions, +161 −59)
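Soft capping, the feature being synced here, squashes the attention logits through a scaled tanh before the softmax so their magnitude never exceeds the cap; this is the same formulation as the reference implementation added to tests/test_flash_attn.py further down in this diff. A minimal PyTorch sketch (tensor shapes and names are illustrative, not part of the commit):

import torch

def soft_cap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Mirrors the reference added below in attention_ref():
    #   scores /= softcap; scores = scores.tanh(); scores *= softcap
    # A softcap of 0.0 is treated as "disabled".
    if softcap <= 0.0:
        return scores
    return torch.tanh(scores / softcap) * softcap

# Illustrative use on raw QK^T logits, before masking and softmax.
scores = torch.randn(2, 4, 128, 128)  # (batch, heads, seqlen_q, seqlen_k)
capped = soft_cap(scores, softcap=50.0)
assert capped.abs().max().item() < 50.0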
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu          +1  −1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu   +7  −0
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu          +1  −1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu   +7  −0
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu          +1  −1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu   +7  −0
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu          +1  −1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu   +7  −0
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu          +1  −1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu   +7  −0
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu          +1  −1
csrc/flash_attn/src/generate_kernels.py                          +14 −9
csrc/flash_attn/src/kernel_traits.h                              +1  −1
csrc/flash_attn/src/mask.h                                       +3  −3
csrc/flash_attn/src/rotary.h                                     +1  −1
csrc/flash_attn/src/static_switch.h                              +10 −0
csrc/flash_attn/src/utils.h                                      +1  −2
setup.py                                                         +33 −16
tests/test_flash_attn.py                                         +55 −19
training/Dockerfile                                              +2  −2
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/generate_kernels.py
@@ -16,17 +16,18 @@ DTYPE_MAP = {
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
-    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
 
 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
 """
 
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
@@ -43,13 +44,14 @@ class Kernel:
     sm: int
     dtype: str
     head_dim: int
+    is_causal: bool
     direction: str
 
     @property
     def template(self) -> str:
         if self.direction == "fwd":
             return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
@@ -57,18 +59,21 @@ class Kernel:
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
 
     @property
     def filename(self) -> str:
-        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
+        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
 
 
 def get_all_kernels() -> List[Kernel]:
-    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd", "fwd_split"]:
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
+    for direction in ["fwd", "fwd_split"]:
+        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
+    for direction in ["bwd"]:
+        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 
 
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
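To make the effect of the new is_causal axis concrete, here is a small, self-contained sketch of the filename scheme the generator now produces. The KernelName class below is a trimmed stand-in for the real Kernel dataclass, written only for illustration:

from dataclasses import dataclass

@dataclass
class KernelName:
    # Trimmed stand-in for the Kernel dataclass in generate_kernels.py.
    direction: str
    head_dim: int
    dtype: str
    is_causal: str  # "true" / "false", matching the generator's string flags
    sm: int = 80

    @property
    def filename(self) -> str:
        causal = "causal_" if self.is_causal == "true" else ""
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{causal}sm{self.sm}.cu"

# Each forward (and split) kernel now comes in a causal and a non-causal variant:
print(KernelName("fwd_split", 32, "fp16", "false").filename)  # flash_fwd_split_hdim32_fp16_sm80.cu
print(KernelName("fwd_split", 32, "fp16", "true").filename)   # flash_fwd_split_hdim32_fp16_causal_sm80.cu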
csrc/flash_attn/src/kernel_traits.h
@@ -4,7 +4,7 @@
 #pragma once
 
-#include "cute/algorithm/copy.hpp"
+#include "cute/tensor.hpp"
 
 #include "cutlass/cutlass.h"
 #include "cutlass/layout/layout.h"
csrc/flash_attn/src/mask.h
@@ -13,7 +13,7 @@ using namespace cute;
 template <typename Engine, typename Layout>
 __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                            const int col_idx_offset_ = 0) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride,
                                                  const int window_size_left, const int window_size_right) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 2, "Only support 2D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
csrc/flash_attn/src/rotary.h
@@ -4,7 +4,7 @@
 #pragma once
 
-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>
 
 #include "utils.h"
csrc/flash_attn/src/static_switch.h
@@ -56,6 +56,16 @@
 #define EVENK_SWITCH BOOL_SWITCH
 #endif
 
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                       \
+    constexpr static bool CONST_NAME = false; \
+    return __VA_ARGS__();                     \
+  }()
+#else
+#define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
 #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
   [&] {                                     \
csrc/flash_attn/src/utils.h
@@ -14,8 +14,7 @@
 #include <cuda_bf16.h>
 #endif
 
-#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
+#include <cute/tensor.hpp>
 
 #include <cutlass/array.h>
 #include <cutlass/cutlass.h>
setup.py
@@ -151,22 +151,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-                # "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
@@ -183,6 +183,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3", "-std=c++17"] + generator_flag,
@@ -203,6 +219,7 @@ if not SKIP_CUDA_BUILD:
                     # "-DFLASHATTENTION_DISABLE_BACKWARD",
                     "-DFLASHATTENTION_DISABLE_DROPOUT",
                     # "-DFLASHATTENTION_DISABLE_ALIBI",
+                    # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                     "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                 ]
tests/test_flash_attn.py
@@ -216,6 +216,7 @@ def attention_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -233,7 +234,7 @@ def attention_ref(
         window_size: (int, int), left and right window size
         upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
             output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
             without changing the math. This is to estimate the numerical error from operation
             reordering.
     Output:
@@ -253,6 +254,10 @@ def attention_ref(
         scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
     else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if window_size[0] >= 0 or window_size[1] >= 0:
@@ -298,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -313,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
        reorder_ops=reorder_ops,
     )
@@ -325,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -340,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
        reorder_ops=reorder_ops,
     )
@@ -877,23 +886,29 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -918,6 +933,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -930,6 +946,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -984,6 +1001,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
@@ -995,6 +1013,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
        )
@@ -1010,6 +1029,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
@@ -1022,6 +1042,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
        )
@@ -1036,7 +1057,7 @@ def test_flash_attn_output(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq,
@@ -1092,7 +1113,7 @@ def test_flash_attn_output(
     if not alibi:
         assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1133,24 +1154,31 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
     ):
         pytest.skip()  # Reference implementation OOM
+    if softcap > 0.0 and dropout_p > 0.0:
+        pytest.skip("Softcap and dropout not supported together")
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 4
-    nheads = 9
-    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Ensure the values of qk are at least within softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -1198,6 +1226,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1229,6 +1258,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1288,6 +1318,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
@@ -1299,6 +1330,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
        )
@@ -1314,6 +1346,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
@@ -1326,6 +1359,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
        )
@@ -1339,7 +1373,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         if kvpacked:
             (
                 dq_unpad,
@@ -1396,9 +1430,9 @@ def test_flash_attn_varlen_output(
     assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
     # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
     if not alibi:
-        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
+        assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1917,9 +1951,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-        if new_kv
-        else (seqlen_k + 1),
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,
@@ -2455,12 +2491,12 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
 
 
 @pytest.mark.parametrize("dtype", [torch.float16])
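Taken together, the test changes above pass the new softcap argument both to the flash-attention kernels and to the PyTorch reference. A rough end-to-end sketch of that comparison is below; it assumes flash_attn_func accepts a softcap keyword (inferred from the calls visible in this diff, not spelled out here) and requires a CUDA device with the extension installed:

import math
import torch
from flash_attn import flash_attn_func  # softcap keyword assumed, per the updated tests

def attention_ref_softcap(q, k, v, softcap):
    # Reference path from this diff: scale by 1/sqrt(d), soft-cap with tanh, softmax, weighted sum.
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(q.shape[-1]), k)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)

q, k, v = (torch.randn(4, 512, 4, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False, softcap=50.0)
out_ref = attention_ref_softcap(q.float(), k.float(), v.float(), softcap=50.0)
print((out.float() - out_ref).abs().max())  # expected to be small (fp16 rounding)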
training/Dockerfile
@@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.5.7
+RUN pip install flash-attn==2.6.0
 
 # Install CUDA extensions for fused dense
-RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.6.0#subdirectory=csrc/fused_dense_lib