gaoqiong / flash-attention · Commits

Commit f66603cb
Authored Jun 29, 2022 by Tri Dao
Parent: 450b64fe

Support batch size > 64K by swapping grid.x and grid.y
Showing 9 changed files with 18 additions and 18 deletions (+18 -18)
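Background on the limit being worked around here: CUDA allows gridDim.x to hold up to 2^31 - 1 blocks, while gridDim.y and gridDim.z are capped at 65535. The launchers below previously built the grid as dim3 grid(params.h, params.b), placing the batch size on gridDim.y and so failing once the batch exceeds 65535 (about 64K); dim3 grid(params.b, params.h) moves the batch onto gridDim.x. The following standalone sketch illustrates the limit; dummy_kernel and the sizes b and h are placeholders, not code from this repository.

#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel: reads the batch/head block indices the way the
// post-swap code does (batch on x, head on y) and does nothing else.
__global__ void dummy_kernel() {
    const int bidb = blockIdx.x;  // batch index
    const int bidh = blockIdx.y;  // head index
    (void)bidb;
    (void)bidh;
}

int main() {
    const int b = 100000;  // batch size > 64K
    const int h = 16;      // number of attention heads

    // Old layout: grid(h, b) puts the batch on gridDim.y (limit 65535),
    // so this launch is rejected with an invalid-configuration error.
    dim3 grid_old(h, b);
    dummy_kernel<<<grid_old, 32>>>();
    printf("grid(h, b): %s\n", cudaGetErrorString(cudaGetLastError()));

    // New layout: grid(b, h) puts the batch on gridDim.x (limit 2^31 - 1),
    // so the same problem size launches fine.
    dim3 grid_new(b, h);
    dummy_kernel<<<grid_new, 32>>>();
    printf("grid(b, h): %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaDeviceSynchronize();
    return 0;
}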
csrc/flash_attn/src/fmha/gmem_tile.h                             +2 -2
csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu    +1 -1
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h           +4 -4
csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu         +1 -1
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h                +2 -2
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu          +1 -1
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h                 +4 -4
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu               +1 -1
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h                      +2 -2
csrc/flash_attn/src/fmha/gmem_tile.h

@@ -456,9 +456,9 @@ struct Gmem_summary_stats {
         : ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
 
         // The block index for the batch.
-        const int bidb = blockIdx.y;
+        const int bidb = blockIdx.x;
         // The block index for the head.
-        const int bidh = blockIdx.x;
+        const int bidh = blockIdx.y;
         // The block index.
         // size_t bidx = bidb * params.h + bidh;
         uint32_t bidx = bidb * params.h + bidh;
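A side note on the hunk above: the flattened per-(batch, head) index bidb * params.h + bidh keeps the same value, since only the assignment of blockIdx components to bidb and bidh is swapped, so the memory offsets derived from bidx are unaffected. A minimal sketch of that invariant, using a hypothetical kernel and an assumed head count h (not code from this file):

#include <cstdint>

// Hypothetical kernel mirroring the post-swap index reads: one block per
// (batch, head) pair, with the batch on grid.x and the head on grid.y.
__global__ void index_demo(int h, uint32_t *out) {
    const int bidb = blockIdx.x;            // batch index, now on grid.x
    const int bidh = blockIdx.y;            // head index, now on grid.y
    const uint32_t bidx = bidb * h + bidh;  // same linear index as before the swap
    if (threadIdx.x == 0) {
        out[bidx] = bidx;                   // one entry per (batch, head) block
    }
}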
csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu

@@ -45,7 +45,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h

@@ -118,9 +118,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -729,9 +729,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu

@@ -68,7 +68,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
         return;
     }
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h

@@ -497,9 +497,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_block_1xN_loop(const Params &params) {
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -44,7 +44,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -119,9 +119,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -683,9 +683,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -68,7 +68,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
         return;
     }
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -621,9 +621,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_1xN_loop(const Params &params) {
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;