gaoqiong / flash-attention

Commit bb9beb36 authored Sep 12, 2023 by Tri Dao

Remove some unused headers

parent 08c295c0
Showing 7 changed files with 6 additions and 59 deletions
csrc/flash_attn/flash_api.cpp                       +3  -1
csrc/flash_attn/src/flash.h                         +1  -2
csrc/flash_attn/src/flash_bwd_kernel.h              +0  -3
csrc/flash_attn/src/flash_fwd_kernel.h              +0  -49
csrc/flash_attn/src/flash_fwd_launch_template.h     +0  -1
csrc/flash_attn/src/softmax.h                       +1  -2
setup.py                                            +1  -1
csrc/flash_attn/flash_api.cpp
@@ -2,7 +2,9 @@
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/
-#include <torch/extension.h>
+// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#include <torch/python.h>
+#include <torch/nn/functional.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
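The new comment in flash_api.cpp states the motivation for this file's change: torch/extension.h is an umbrella header that pulls in much more of libtorch than the binding code needs, while torch/python.h (roughly, pybind11 plus the Python/tensor bindings) and torch/nn/functional.h (the functional ops the file uses, such as padding) cover what is required. A minimal, hypothetical sketch of that include pattern follows; the function pad_last_dim and the module contents are illustrative only, and it assumes the extension is built with torch.utils.cpp_extension, which defines TORCH_EXTENSION_NAME:

// Hypothetical example, not part of this commit: a small extension translation
// unit built against the two lighter headers instead of torch/extension.h.
#include <torch/python.h>          // pybind11 + Python <-> at::Tensor bindings
#include <torch/nn/functional.h>   // torch::nn::functional::pad and friends

// Pad the last dimension of a tensor on the right by `pad_right` elements.
at::Tensor pad_last_dim(const at::Tensor &x, int64_t pad_right) {
    namespace F = torch::nn::functional;
    return F::pad(x, F::PadFuncOptions({0, pad_right}));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("pad_last_dim", &pad_last_dim, "Right-pad the last dimension of a tensor");
}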
csrc/flash_attn/src/flash.h
@@ -13,8 +13,7 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif
-#include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAGraphsUtils.cuh>  // For at::cuda::philox::unpack
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -5,18 +5,15 @@
 #pragma once
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
 namespace flash {
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -4,20 +4,16 @@
 #pragma once
-#include <cmath>
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
 namespace flash {
@@ -25,49 +21,6 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<int MMA_M, class... Args, class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
-template<int MMA_M, class... Args, class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},  // TODO: Shouldn't this be size<1>?
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
 template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
 inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                          Tensor2 &acc_o, float softmax_scale_log2) {
@@ -256,7 +209,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
-    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
@@ -558,7 +510,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Partition sO to match the accumulator partitioning
     auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
-    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
     Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
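The softmax_rescale_o signature kept as context above is the online-softmax step of the forward kernel: the scores of each new key block, their running max and running sum, and the unnormalized output accumulator acc_o are updated block by block. The following is a schematic, scalar sketch of that rescaling idea for a single attention row, under assumed semantics; it is not the kernel's tensor code, and it uses expf with a plain scale where the real code works with exp2f and softmax_scale_log2.

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax update for a single query row (schematic only).
// scores:  raw dot-product scores of the new key block, replaced in place by probabilities
// row_max: running maximum over all blocks seen so far
// row_sum: running sum of exponentials under the current row_max
// acc_o:   unnormalized output row accumulated so far
void softmax_rescale_row(std::vector<float> &scores, float &row_max, float &row_sum,
                         std::vector<float> &acc_o, float softmax_scale) {
    const float prev_max = row_max;
    for (float s : scores) { row_max = std::max(row_max, s); }
    // Everything accumulated under the old max must be rescaled to the new one.
    const float correction = std::exp((prev_max - row_max) * softmax_scale);
    row_sum *= correction;
    for (float &o : acc_o) { o *= correction; }
    // Exponentiate the new block in place and grow the running sum; the caller
    // then adds P @ V for this block into acc_o and divides by row_sum at the end.
    for (float &s : scores) {
        s = std::exp((s - row_max) * softmax_scale);
        row_sum += s;
    }
}

Initializing row_max to -INFINITY and row_sum and acc_o to zero makes the first block a degenerate case of the same update; the Is_first flag in the signature above suggests the real kernel simply special-cases that first block instead of paying for a no-op rescale.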
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -76,7 +76,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
             // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-            // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
             auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
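The two nested BOOL_SWITCH calls in the hunk above turn runtime conditions (params.num_splits > 1 and params.knew_ptr != nullptr) into the compile-time template parameters Split and Append_KV, so each combination selects a separately instantiated kernel. A minimal sketch of that dispatch pattern is below; the names MY_BOOL_SWITCH, launch_kernel_variant, and run are hypothetical, and a printf stands in for the real kernel launch.

#include <cstdio>

// Dispatch a runtime boolean to a compile-time constant via an
// immediately-invoked lambda; __VA_ARGS__ lets the body contain commas.
#define MY_BOOL_SWITCH(COND, CONST_NAME, ...)             \
    [&] {                                                 \
        if (COND) {                                       \
            static constexpr bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            static constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()

template <bool Split, bool Append_KV>
void launch_kernel_variant() {
    // The real launcher would take &some_kernel<..., Split, Append_KV> here.
    std::printf("Split=%d Append_KV=%d\n", int(Split), int(Append_KV));
}

void run(int num_splits, const void *knew_ptr) {
    MY_BOOL_SWITCH(num_splits > 1, Split, [&] {
        MY_BOOL_SWITCH(knew_ptr != nullptr, Append_KV, [&] {
            launch_kernel_variant<Split, Append_KV>();
        });
    });
}

Each of the four (Split, Append_KV) combinations compiles to its own instantiation, which is the usual reason for turning such flags into template parameters rather than branching at runtime inside the kernel.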
csrc/flash_attn/src/softmax.h
@@ -8,8 +8,7 @@
 #include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
 #include "philox.cuh"
 #include "utils.h"
setup.py
@@ -189,7 +189,7 @@ if not SKIP_CUDA_BUILD:
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                    "--ptxas-options=-v",
+                    # "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
                     "-lineinfo"
                 ]