Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
cb516f85
Unverified
Commit
cb516f85
authored
Jul 22, 2024
by
Cameron Shinn
Committed by
GitHub
Jul 22, 2024
Browse files
Remove torchlib dependency from cpp files (#1083)
parent
5f1ae4a3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
26 deletions
+25
-26
hopper/flash.h
hopper/flash.h
+0
-11
hopper/flash_bwd_launch_template.h
hopper/flash_bwd_launch_template.h
+8
-9
hopper/flash_fwd_launch_template.h
hopper/flash_fwd_launch_template.h
+4
-5
hopper/utils.h
hopper/utils.h
+13
-1
No files found.
hopper/flash.h
View file @
cb516f85
...
@@ -7,14 +7,6 @@
...
@@ -7,14 +7,6 @@
#include <cuda.h>
#include <cuda.h>
#include <vector>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include "cutlass/fast_math.h" // For cutlass::FastDivmod
#include "cutlass/fast_math.h" // For cutlass::FastDivmod
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -118,9 +110,6 @@ struct Flash_fwd_params : public Qkv_params {
...
@@ -118,9 +110,6 @@ struct Flash_fwd_params : public Qkv_params {
// Local window size
// Local window size
int
window_size_left
,
window_size_right
;
int
window_size_left
,
window_size_right
;
// Random state.
at
::
PhiloxCudaState
philox_args
;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t
*
rng_state
;
uint64_t
*
rng_state
;
...
...
hopper/flash_bwd_launch_template.h
View file @
cb516f85
...
@@ -4,8 +4,6 @@
...
@@ -4,8 +4,6 @@
#pragma once
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cluster_launch.hpp"
...
@@ -15,6 +13,7 @@
...
@@ -15,6 +13,7 @@
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
#include "flash_bwd_kernel.h"
#include "kernel_traits.h"
#include "kernel_traits.h"
#include "utils.h"
template
<
bool
Clear_dQaccum
=
true
,
typename
Kernel_traits
>
template
<
bool
Clear_dQaccum
=
true
,
typename
Kernel_traits
>
__global__
void
flash_bwd_dot_do_o_kernel
(
const
Flash_bwd_params
params
)
{
__global__
void
flash_bwd_dot_do_o_kernel
(
const
Flash_bwd_params
params
)
{
...
@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
...
@@ -38,7 +37,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
flash_bwd_dot_do_o_kernel
<
true
,
Kernel_traits
><<<
grid_m
,
Kernel_traits
::
kNThreadsNonWS
,
0
,
stream
>>>
(
params
);
flash_bwd_dot_do_o_kernel
<
true
,
Kernel_traits
><<<
grid_m
,
Kernel_traits
::
kNThreadsNonWS
,
0
,
stream
>>>
(
params
);
// If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
// If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum
// flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
// flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreadsNonWS, 0, stream>>>(params);
C
10
_CUDA_KERNEL_LAUNCH
_CHECK
();
C
HECK
_CUDA_KERNEL_LAUNCH
();
using
Element
=
typename
Kernel_traits
::
Element
;
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
...
@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
...
@@ -157,7 +156,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
// printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds);
if
(
smem_size
>=
48
*
1024
)
{
if
(
smem_size
>=
48
*
1024
)
{
C
10
_CUDA
_CHECK
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
C
HECK
_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
}
static
constexpr
int
ctaSize
=
Kernel_traits
::
kNWarps
*
32
;
static
constexpr
int
ctaSize
=
Kernel_traits
::
kNWarps
*
32
;
...
@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
...
@@ -179,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
}
}
// cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
// cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO,
// tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
// tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV);
C
10
_CUDA_KERNEL_LAUNCH
_CHECK
();
C
HECK
_CUDA_KERNEL_LAUNCH
();
auto
tma_load_dQaccum
=
make_tma_copy
(
auto
tma_load_dQaccum
=
make_tma_copy
(
typename
cute
::
SM90_TMA_LOAD
{},
typename
cute
::
SM90_TMA_LOAD
{},
...
@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
...
@@ -190,20 +189,20 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
auto
kernel_dq
=
&
flash
::
convert_dQ
<
Kernel_traits
,
decltype
(
tma_load_dQaccum
)
>
;
auto
kernel_dq
=
&
flash
::
convert_dQ
<
Kernel_traits
,
decltype
(
tma_load_dQaccum
)
>
;
if
(
Kernel_traits
::
kSmemdQSize
*
2
+
8
>=
48
*
1024
)
{
if
(
Kernel_traits
::
kSmemdQSize
*
2
+
8
>=
48
*
1024
)
{
C
10
_CUDA
_CHECK
(
cudaFuncSetAttribute
(
C
HECK
_CUDA
(
cudaFuncSetAttribute
(
kernel_dq
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
kSmemdQSize
*
2
+
8
));
kernel_dq
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
kSmemdQSize
*
2
+
8
));
}
}
kernel_dq
<<<
grid_m
,
Kernel_traits
::
kNThreadsdQ
,
Kernel_traits
::
kSmemdQSize
*
2
+
8
,
stream
>>>
(
params
,
tma_load_dQaccum
);
kernel_dq
<<<
grid_m
,
Kernel_traits
::
kNThreadsdQ
,
Kernel_traits
::
kSmemdQSize
*
2
+
8
,
stream
>>>
(
params
,
tma_load_dQaccum
);
C
10
_CUDA_KERNEL_LAUNCH
_CHECK
();
C
HECK
_CUDA_KERNEL_LAUNCH
();
// auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
// auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
// if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
// if (Kernel_traits::kSmemdKVSize >= 48 * 1024) {
// C
10
_CUDA
_CHECK
(cudaFuncSetAttribute(
// C
HECK
_CUDA(cudaFuncSetAttribute(
// kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
// kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize));
// }
// }
// int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
// int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN);
// dim3 grid_n(num_n_block, params.b, params.h);
// dim3 grid_n(num_n_block, params.b, params.h);
// kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
// kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemdKVSize, stream>>>(params);
// C
10
_CUDA_KERNEL_LAUNCH
_CHECK
();
// C
HECK
_CUDA_KERNEL_LAUNCH();
}
}
...
...
hopper/flash_fwd_launch_template.h
View file @
cb516f85
...
@@ -4,8 +4,6 @@
...
@@ -4,8 +4,6 @@
#pragma once
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "cute/tensor.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cutlass.h"
...
@@ -16,6 +14,7 @@
...
@@ -16,6 +14,7 @@
#include "tile_scheduler.hpp"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "kernel_traits.h"
#include "utils.h"
template
<
typename
Kernel_traits
,
bool
Is_causal
>
template
<
typename
Kernel_traits
,
bool
Is_causal
>
...
@@ -66,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -66,7 +65,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
if
(
smem_size
>=
48
*
1024
)
{
if
(
smem_size
>=
48
*
1024
)
{
C
10
_CUDA
_CHECK
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
C
HECK
_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
}
}
int
device
;
int
device
;
...
@@ -75,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -75,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
cudaError
status_
=
cudaDeviceGetAttribute
(
cudaError
status_
=
cudaDeviceGetAttribute
(
&
multiprocessor_count
,
cudaDevAttrMultiProcessorCount
,
device
);
&
multiprocessor_count
,
cudaDevAttrMultiProcessorCount
,
device
);
if
(
status_
!=
cudaSuccess
)
{
if
(
status_
!=
cudaSuccess
)
{
C
10
_CUDA
_CHECK
(
status_
);
C
HECK
_CUDA
(
status_
);
}
}
dim3
grid_dims
=
Scheduler
::
get_grid_dim
(
scheduler_args
,
multiprocessor_count
);
dim3
grid_dims
=
Scheduler
::
get_grid_dim
(
scheduler_args
,
multiprocessor_count
);
static
constexpr
int
ctaSize
=
Kernel_traits
::
kNWarps
*
32
;
static
constexpr
int
ctaSize
=
Kernel_traits
::
kNWarps
*
32
;
...
@@ -83,7 +82,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -83,7 +82,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
dim3
cluster_dims
(
size
<
0
>
(
ClusterShape
{}),
size
<
1
>
(
ClusterShape
{}),
size
<
2
>
(
ClusterShape
{}));
dim3
cluster_dims
(
size
<
0
>
(
ClusterShape
{}),
size
<
1
>
(
ClusterShape
{}),
size
<
2
>
(
ClusterShape
{}));
cutlass
::
ClusterLaunchParams
launch_params
{
grid_dims
,
block_dims
,
cluster_dims
,
smem_size
,
stream
};
cutlass
::
ClusterLaunchParams
launch_params
{
grid_dims
,
block_dims
,
cluster_dims
,
smem_size
,
stream
};
cutlass
::
launch_kernel_on_cluster
(
launch_params
,
kernel
,
mainloop_params
,
epilogue_params
,
scheduler_params
);
cutlass
::
launch_kernel_on_cluster
(
launch_params
,
kernel
,
mainloop_params
,
epilogue_params
,
scheduler_params
);
C
10
_CUDA_KERNEL_LAUNCH
_CHECK
();
C
HECK
_CUDA_KERNEL_LAUNCH
();
}
}
template
<
typename
T
>
template
<
typename
T
>
...
...
hopper/utils.h
View file @
cb516f85
...
@@ -21,6 +21,18 @@
...
@@ -21,6 +21,18 @@
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_types.h>
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while(0)
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
namespace
flash
{
namespace
flash
{
using
namespace
cute
;
using
namespace
cute
;
...
@@ -62,7 +74,7 @@ struct Allreduce {
...
@@ -62,7 +74,7 @@ struct Allreduce {
template
<
>
template
<
>
struct
Allreduce
<
2
>
{
struct
Allreduce
<
2
>
{
template
<
typename
T
,
typename
Operator
>
template
<
typename
T
,
typename
Operator
>
static
__device__
__forceinline__
T
run
(
T
x
,
Operator
&
op
)
{
static
__device__
__forceinline__
T
run
(
T
x
,
Operator
&
op
)
{
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
1
));
x
=
op
(
x
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
x
,
1
));
return
x
;
return
x
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment