gaoqiong / flash-attention / Commits / 871db479

Commit 871db479, authored Oct 21, 2022 by Tri Dao
Don't need to run configure for the forward pass
parent 7fc39832
Showing 4 changed files with 30 additions and 42 deletions (+30, -42):

csrc/flash_attn/fmha_api.cpp                         +3  -5
csrc/flash_attn/src/fmha.h                           +1  -1
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu   +18 -31
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h          +8  -5
csrc/flash_attn/fmha_api.cpp

@@ -294,7 +294,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      is_causal,
                      num_splits);
-    run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -307,7 +306,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
-    run_fmha_fp16_sm80(launch_params, /*configure=*/false);
+    run_fmha_fp16_sm80(launch_params);
     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
@@ -453,9 +452,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
-    // We're gonna reset the rng state in Python after this kernel, so the counter offset
-    // here doesn't matter at all. We just choose an arbitrary number.
-    int64_t counter_offset = 4;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
     if (is_dropout) {
         // See Note [Acquire lock when using random generators]
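For context, a minimal standalone sketch of the offset arithmetic the new mha_bwd line relies on. The struct and the numbers below are simplified, hypothetical stand-ins (not the real FMHA params or Launch_params types); only the formula counter_offset = params.b * params.h * 32 comes from the diff above.

    #include <cstdint>
    #include <iostream>

    // Simplified stand-in for the batch/head fields of the real params struct.
    struct ParamsSketch {
        int b;  // batch size
        int h;  // number of heads
    };

    int main() {
        ParamsSketch params{8, 12};  // hypothetical: batch size 8, 12 heads

        // The custom RNG advances the philox counter by batch_size * nheads * 32,
        // so the offset is known up front without a separate configure pass.
        int64_t counter_offset = static_cast<int64_t>(params.b) * params.h * 32;

        std::cout << "philox counter offset = " << counter_offset << "\n";  // prints 3072
        return 0;
    }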
csrc/flash_attn/src/fmha.h

@@ -191,7 +191,7 @@ struct Launch_params{
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -65,22 +65,10 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
 }
 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
-                              const bool configure) {
+void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
-    if (configure) {
-        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
-        constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
-        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
-        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
-        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
-        launch_params.elts_per_thread = elts_per_head;
-        return;
-    }
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
@@ -123,38 +111,37 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     });
 }
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
-                        const bool configure) {
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k == 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else {
                 // TD [2022-05-15] 512 gives wrong results rn
                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 32) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 64) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
@@ -166,30 +153,30 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             // For causal=True, block size 128 seems always faster (for small & large batch size).
             // So we're just gonna use block size 128 for simplicity.
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         }
         // if (launch_params.params.d == 64) {
         //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
         //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         // }
         // if (launch_params.params.d == 64) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else if( launch_params.params.seqlen_k >= 256 ) {
         //         if (dprops->major == 8 && dprops->minor >= 0) {
         //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         } else if (dprops->major == 7 && dprops->minor == 5) {
         //             if (launch_params.is_dropout) { // Need to use the same block size as backward
         //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //             } else {
         //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //             }
         //         }
         //     }
         // }
@@ -197,16 +184,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
         // if (launch_params.params.d == 128) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else {
         //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
         //             // TD [2022-06-05] Keep K in registers to reduce register spilling
         //             // Gives about 6% speedup compared to using block size 128.
         //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         } else { // Need to use the same block size as backward
         //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         }
         //     }
         // }
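For reference, the deleted `if (configure)` branch only computed how many softmax elements each thread touches and stashed the result in launch_params.elts_per_thread; per the commit message, the forward pass no longer needs that pre-pass. The sketch below replays the same arithmetic with hypothetical placeholder constants (the tile sizes and sequence lengths are illustrative, not the real Kernel_traits or Mma_tile_p values).

    #include <cstddef>
    #include <iostream>

    int main() {
        // Hypothetical placeholders for the Kernel_traits / Mma_tile_p constants.
        constexpr int M = 16;             // Cta_tile_p::M (query rows per CTA tile)
        constexpr int blocksize_c = 128;  // Cta_tile_p::N (keys per loop step)
        constexpr std::size_t MMAS_M = 1;
        constexpr std::size_t MMAS_N = 4;

        const int seqlen_q = 512, seqlen_k = 512;  // hypothetical sequence lengths
        const std::size_t STEPS = (seqlen_q + M - 1) / M;
        const int loop_steps = (seqlen_k + blocksize_c - 1) / blocksize_c;

        // Same formula the removed configure branch used for elts_per_thread.
        const std::size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
        std::cout << "elts_per_thread would have been " << elts_per_head << "\n";
        return 0;
    }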
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -250,6 +250,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // The thread index.
     const int tidx = threadIdx.x;
+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
+
     const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -683,14 +686,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
     }
 }
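The last file change moves step_stride from a kernel argument to a value read inside the kernel from gridDim.z (which, per the new comment, equals params.num_splits). Below is a minimal CUDA sketch of that general pattern; the kernel name, launch shape, and step count are illustrative assumptions, not the real fmha kernels.

    #include <cstdio>

    // Each z-slice of the grid processes every gridDim.z-th step, mirroring how
    // device_1xN_ now derives its stride from the launch configuration instead of
    // receiving it as a host-supplied parameter.
    __global__ void stride_over_steps(int total_steps) {
        const int step_stride = gridDim.z;  // what the host used to pass explicitly
        for (int step = blockIdx.z; step < total_steps; step += step_stride) {
            if (threadIdx.x == 0) {
                printf("z-block %d handles step %d\n", blockIdx.z, step);
            }
        }
    }

    int main() {
        dim3 grid(1, 1, 4);                   // hypothetical num_splits = 4
        stride_over_steps<<<grid, 32>>>(10);  // 10 hypothetical steps
        cudaDeviceSynchronize();
        return 0;
    }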