gaoqiong / flash-attention · Commits

Commit e518a4b3, authored Jul 08, 2022 by Tri Dao

Refactor to template on __half, implement bf16 util functions

Parent: 2dc1b205

Showing 9 changed files with 127 additions and 115 deletions (+127 / -115)
csrc/flash_attn/src/fmha/gemm.h                           +2    -1
csrc/flash_attn/src/fmha/gmem_tile.h                      +5    -32
csrc/flash_attn/src/fmha/smem_tile.h                      +5    -5
csrc/flash_attn/src/fmha/softmax.h                        +5    -23
csrc/flash_attn/src/fmha/utils.h                          +84   -25
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h    +7    -7
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h         +3    -4
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h          +13   -14
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h               +3    -4
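
All nine files follow one pattern: utility functions that were hard-coded to fp16 (float2_to_half2, float4_to_half4, hrelu2, hmulsum8, ...) become templates on an element type, __half specializations reproduce the old behaviour, __nv_bfloat16 specializations are added behind an sm_80 guard, and the kernels pass an element type (still __half everywhere for now) down to them. A minimal, self-contained sketch of that specialize-per-format pattern, for illustration only — pack2 and pack_kernel are made-up names, not code from this commit:

    // Sketch of the "template on element type, specialize per 16-bit format"
    // pattern this commit introduces. Build with: nvcc -arch=sm_80 sketch.cu
    #include <cstdint>
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Primary template: pack two fp32 values into one 32-bit register of 16-bit values.
    template<typename T> __device__ uint32_t pack2(float a, float b);

    template<> __device__ uint32_t pack2<__half>(float a, float b) {
        __half2 h = __floats2half2_rn(a, b);             // round-to-nearest fp16 pair
        return reinterpret_cast<uint32_t &>(h);
    }

    template<> __device__ uint32_t pack2<__nv_bfloat16>(float a, float b) {
        __nv_bfloat162 h = __floats2bfloat162_rn(a, b);  // round-to-nearest bf16 pair
        return reinterpret_cast<uint32_t &>(h);
    }

    // Kernels forward a single elem_type parameter to the utilities, the way the
    // fmha kernels now pass __half explicitly at every call site.
    template<typename elem_type>
    __global__ void pack_kernel(const float2 *in, uint32_t *out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) { out[i] = pack2<elem_type>(in[i].x, in[i].y); }
    }

Instantiating pack_kernel<__nv_bfloat16> alongside pack_kernel<__half> is the direction this prepares; in this commit every call site still says __half.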
csrc/flash_attn/src/fmha/gemm.h

@@ -142,10 +142,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
         }
     }
 
+    template<typename elem_type>
     inline __device__ void hrelu_() {
         #pragma unroll
         for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
-            this->reg(ii) = fmha::hrelu2(this->reg(ii));
+            this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
         }
     }
 };
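
Making hrelu_() a member function template is also why later files in this commit spell the call frag.template hrelu_<__half>(): when the object's type depends on a template parameter, C++ needs the template keyword so the < is parsed as a template argument list rather than a comparison. A tiny standalone illustration (hypothetical names, host-only):

    #include <cstdio>

    struct Fragment {
        // Member function template, analogous to the new hrelu_<elem_type>().
        template<typename elem_type>
        void hrelu_() { printf("relu over %zu-byte elements\n", sizeof(elem_type)); }
    };

    // Non-dependent context: the plain syntax is fine.
    void call_plain(Fragment &f) { f.hrelu_<float>(); }

    // Dependent context: Frag depends on a template parameter, so the call needs
    // the .template disambiguator -- the spelling used in the fmha kernels below.
    template<typename Frag>
    void call_dependent(Frag &f) { f.template hrelu_<float>(); }

    int main() {
        Fragment f;
        call_plain(f);
        call_dependent(f);
        return 0;
    }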
csrc/flash_attn/src/fmha/gmem_tile.h

@@ -27,6 +27,8 @@
 #pragma once
 
+#include <cuda_fp16.h>
+
 namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -219,6 +221,7 @@ struct Gmem_tile_o {
     }
 
     // Store data to global memory.
+    template<typename elem_type=__half>
     inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
         int row_ = tidx_ / THREADS_PER_ROW;
         #pragma unroll
@@ -237,7 +240,7 @@ struct Gmem_tile_o {
             float y = reinterpret_cast<const float &>(src[ii].y);
             float z = reinterpret_cast<const float &>(src[ii].z);
             float w = reinterpret_cast<const float &>(src[ii].w);
-            uint2 out = float4_to_half4(x, y, z, w);
+            uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
             if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                 fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
             }
@@ -245,7 +248,7 @@ struct Gmem_tile_o {
         }
     }
 
-    // Store data to global memory.
+    // Load data from global memory.
     inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
         static_assert(BYTES_PER_ELEMENT == 4);
         int row_ = tidx_ / THREADS_PER_ROW;
@@ -366,36 +369,6 @@ struct Gmem_tile_mma_s : public Base {
         : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
     }
 
-    // Store to global memory.
-    template<typename Mask>
-    inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
-        #pragma unroll
-        for( int mi = 0; mi < M; mi++ ) {
-            #pragma unroll
-            for( int ni = 0; ni < N; ni++ ) {
-                float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
-                float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
-                float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
-                float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
-                float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
-                float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
-                float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
-                float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
-
-                uint4 dst;
-                dst.x = fmha::float2_to_half2(tmp00, tmp01);
-                dst.y = fmha::float2_to_half2(tmp02, tmp03);
-                dst.z = fmha::float2_to_half2(tmp10, tmp11);
-                dst.w = fmha::float2_to_half2(tmp12, tmp13);
-                if( mask.is_valid(mi, ni, 0, 0) ) {
-                    Base::store(dst, mi, ni);
-                }
-            }
-        }
-    }
-
     // Store to global memory.
     template<typename Mask, typename Fragment>
     inline __device__ void store(const Fragment (&frag)[N][M], const Mask &mask){
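
Because store is now declared template<typename elem_type=__half>, callers that don't name a type keep compiling against the fp16 default; only the paths this commit touches spell store<__half> explicitly. The packing step that replaces float4_to_half4 squeezes four fp32 values into an 8-byte uint2 before the STG; a hedged standalone sketch (float4_pack_sketch is not the repo function):

    #include <cstdint>
    #include <cuda_fp16.h>

    // Pack four fp32 values into a uint2 holding four fp16 values (8 bytes),
    // mirroring what fmha::float4_pack<__half> does for the O-tile store.
    __device__ uint2 float4_pack_sketch(float x, float y, float z, float w) {
        __half2 lo = __floats2half2_rn(x, y);
        __half2 hi = __floats2half2_rn(z, w);
        uint2 out;
        out.x = reinterpret_cast<uint32_t &>(lo);
        out.y = reinterpret_cast<uint32_t &>(hi);
        return out;   // one 8-byte store instead of a 16-byte fp32 store
    }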
csrc/flash_attn/src/fmha/smem_tile.h

@@ -1384,7 +1384,7 @@ struct Smem_tile_mma_epilogue : public Base {
         }
     }
 
-    template<int M, int N>
+    template<typename elem_type=__half, int M, int N>
     inline __device__ void store(const Acc (&acc)[M][N]){
         #pragma unroll
         for( int mi = 0; mi < M; mi++ ) {
@@ -1401,10 +1401,10 @@ struct Smem_tile_mma_epilogue : public Base {
             float tmp12 = acc[mi][ni].elt(6);
             float tmp13 = acc[mi][ni].elt(7);
-            uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
-            uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
-            uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
-            uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
+            uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
+            uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
+            uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
+            uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);
             // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
             // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
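
Each MMA accumulator carries eight fp32 elements (elt(0)..elt(7)); the epilogue packs them pairwise into four 32-bit registers before the swizzled shared-memory store, and float2_pack<elem_type> is what decides whether those registers hold fp16 or bf16 pairs. A hedged sketch of the conversion for one accumulator (the elt-to-pair ordering here is illustrative, not the epilogue's exact mapping):

    #include <cstdint>
    #include <cuda_fp16.h>

    __device__ uint32_t pack_pair(float a, float b) {
        __half2 h = __floats2half2_rn(a, b);
        return reinterpret_cast<uint32_t &>(h);
    }

    // Compress one 8-element fp32 accumulator into a uint4 of four fp16 pairs.
    __device__ uint4 pack_acc8(const float (&elt)[8]) {
        uint4 r;
        r.x = pack_pair(elt[0], elt[1]);
        r.y = pack_pair(elt[2], elt[3]);
        r.z = pack_pair(elt[4], elt[5]);
        r.w = pack_pair(elt[6], elt[7]);
        return r;   // 32 bytes of fp32 become 16 bytes of fp16
    }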
csrc/flash_attn/src/fmha/softmax.h

@@ -34,24 +34,6 @@ namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-struct Sum_ {
-    static constexpr bool IS_SUM = true;
-    static inline __device__ float apply(float x, float y) {
-        return x + y;
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-struct Max_ {
-    static constexpr bool IS_SUM = false;
-    static inline __device__ float apply(float x, float y) {
-        return x > y ? x : y;
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 inline __device__ float apply_exp_(float x, float max) {
     return __expf(x - max);
 }
@@ -508,7 +490,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
     }
 
     // Pack the data to a fragment for the next GEMM.
-    template<int K, int M>
+    template<typename elem_type=__half, int K, int M>
     inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
         #pragma unroll
         for( int mi = 0; mi < M; ++mi ) {
@@ -528,10 +510,10 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
             float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
 
             // Pack to 4 registers.
-            dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
-            dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
-            dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
-            dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
+            dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
+            dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
+            dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
+            dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
         }
     }
 }
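
The register order in pack is worth noting: for each 2x4 tile of probabilities, reg(0)/reg(1) take columns 0-1 of rows 0 and 1, and reg(2)/reg(3) take columns 2-3, matching the Fragment_a layout the next GEMM consumes. A standalone sketch of that interleaving (Frag4 is a stand-in for Fragment_a, not the repo type):

    #include <cstdint>
    #include <cuda_fp16.h>

    struct Frag4 { uint32_t reg[4]; };   // stand-in for one Fragment_a

    __device__ uint32_t pack_pair(float a, float b) {
        __half2 h = __floats2half2_rn(a, b);
        return reinterpret_cast<uint32_t &>(h);
    }

    // elt is the 2-row x 4-column fp32 tile owned by one (mi, ki) pair.
    __device__ Frag4 pack_tile(const float (&elt)[2][4]) {
        Frag4 f;
        f.reg[0] = pack_pair(elt[0][0], elt[0][1]);   // row 0, cols 0-1
        f.reg[1] = pack_pair(elt[1][0], elt[1][1]);   // row 1, cols 0-1
        f.reg[2] = pack_pair(elt[0][2], elt[0][3]);   // row 0, cols 2-3
        f.reg[3] = pack_pair(elt[1][2], elt[1][3]);   // row 1, cols 2-3
        return f;
    }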
csrc/flash_attn/src/fmha/utils.h

@@ -33,6 +33,10 @@
 #include <cuda_fp16.h>
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
 extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -310,12 +314,16 @@ static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
+template<typename T>
+inline __device__ uint32_t hrelu2(uint32_t x);
+
+template<>
+inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
     uint32_t res;
+    const uint32_t zero = 0u;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
+    asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
 #else
-    const uint32_t zero = 0u;
     asm volatile( \
         "{\n" \
         "\t .reg .f16x2 sela;\n" \
@@ -325,6 +333,19 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
 #endif
     return res;
 }
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
+    uint32_t res;
+    const uint32_t zero = 0u;
+    asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+    return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 static inline __device__ uint32_t habs2(uint32_t x) {
     uint32_t res;
     asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
@@ -332,7 +353,7 @@ static inline __device__ uint32_t habs2(uint32_t x) {
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+//
 template< typename T >
 static inline __device__ T clamp(T x, T lb, T ub) {
     return x < lb ? lb : (x > ub ? ub : x);
@@ -370,6 +391,25 @@ static inline __device__ uint32_t float2_to_half2(float a, float b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+inline __device__ uint32_t float2_pack(float a, float b);
+
+template<>
+inline __device__ uint32_t float2_pack<__half>(float a, float b) {
+    __half2 result = __floats2half2_rn(a, b);
+    return reinterpret_cast<uint32_t(&)>(result);
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
+    __nv_bfloat162 result = __floats2bfloat162_rn(a, b);
+    return reinterpret_cast<uint32_t(&)>(result);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 static inline __device__ uint32_t float_to_half2(float a) {
     return float2_to_half2(a, a);
 }
@@ -391,6 +431,16 @@ static inline __device__ uint2 float4_to_half4(float x, float y, float z, float
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
+    uint2 d;
+    d.x = float2_pack<T>(x, y);
+    d.y = float2_pack<T>(z, w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
     uint32_t d;
     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
@@ -404,7 +454,7 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
 #else
-    d = hrelu2(hfma2(a, b, c));
+    d = hrelu2<__half>(hfma2(a, b, c));
 #endif
     return d;
 }
@@ -481,32 +531,41 @@ static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// Converted two half2's into float, then take their dot product.
-// inline __device__ void hfma2_to_float(float &sum, const __half2 a, const __half2 b) {
-static inline __device__ float hfma2_to_float(const __half2 a, const __half2 b) {
-    float2 af = __half22float2(a);
-    float2 bf = __half22float2(b);
+template<typename T>
+inline __device__ float2 half2_unpack(uint32_t a);
+
+template<>
+inline __device__ float2 half2_unpack<__half>(uint32_t a) {
+    return __half22float2(reinterpret_cast<__half2(&)>(a));
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
+    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162(&)>(a));
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Converted two half2's or bf162's into float, then take their dot product.
+template<typename T>
+inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
+    float2 af = fmha::half2_unpack<T>(a);
+    float2 bf = fmha::half2_unpack<T>(b);
     return af.x * bf.x + af.y * bf.y;
-    // sum += af.x * bf.x + af.y * bf.y;
-    // sum = __fmaf_rn(sum, af.x, bf.x);
-    // sum = __fmaf_rn(sum, af.y, bf.y);
-    // float2 prod = __half22float2(__hmul2(a, b));
-    // sum += prod.x + prod.y;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// Converted two vectors of 8 half's into float, then take their dot product.
-static inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
+// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
+template<typename T>
+inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
     float sum;
-    sum  = fmha::hfma2_to_float(reinterpret_cast<const __half2 &>(a.x),
-                                reinterpret_cast<const __half2 &>(b.x));
-    sum += fmha::hfma2_to_float(reinterpret_cast<const __half2 &>(a.y),
-                                reinterpret_cast<const __half2 &>(b.y));
-    sum += fmha::hfma2_to_float(reinterpret_cast<const __half2 &>(a.z),
-                                reinterpret_cast<const __half2 &>(b.z));
-    sum += fmha::hfma2_to_float(reinterpret_cast<const __half2 &>(a.w),
-                                reinterpret_cast<const __half2 &>(b.w));
+    sum  = fmha::hfma2_to_float<T>(a.x, b.x);
+    sum += fmha::hfma2_to_float<T>(a.y, b.y);
+    sum += fmha::hfma2_to_float<T>(a.z, b.z);
+    sum += fmha::hfma2_to_float<T>(a.w, b.w);
     return sum;
 }
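
The new helpers come in matched pairs: float2_pack<T> compresses a float pair into one 32-bit register and half2_unpack<T> expands it back, so hfma2_to_float<T> and hmulsum8<T> can operate on raw uint32_t/uint4 registers without knowing whether they hold fp16 or bf16. A small self-contained check of the round trip and of the 8-element dot product, written against the fp16 path only (an illustrative re-implementation, not the repo header):

    #include <cstdint>
    #include <cstdio>
    #include <cuda_fp16.h>

    __device__ uint32_t pack2(float a, float b) {
        __half2 h = __floats2half2_rn(a, b);
        return reinterpret_cast<uint32_t &>(h);
    }
    __device__ float2 unpack2(uint32_t r) {
        return __half22float2(reinterpret_cast<__half2 &>(r));
    }

    // Dot product of two vectors of 8 fp16 values held in uint4 registers,
    // accumulated in fp32 -- the same shape as fmha::hmulsum8<__half>.
    __device__ float mulsum8(uint4 a, uint4 b) {
        float2 ax = unpack2(a.x), bx = unpack2(b.x);
        float2 ay = unpack2(a.y), by = unpack2(b.y);
        float2 az = unpack2(a.z), bz = unpack2(b.z);
        float2 aw = unpack2(a.w), bw = unpack2(b.w);
        return ax.x * bx.x + ax.y * bx.y + ay.x * by.x + ay.y * by.y
             + az.x * bz.x + az.y * bz.y + aw.x * bw.x + aw.y * bw.y;
    }

    __global__ void check() {
        // 1..8 against 8..1; every value is exact in fp16, so the sum is exact.
        uint4 a = { pack2(1, 2), pack2(3, 4), pack2(5, 6), pack2(7, 8) };
        uint4 b = { pack2(8, 7), pack2(6, 5), pack2(4, 3), pack2(2, 1) };
        printf("mulsum8 = %f (expected 120.0)\n", mulsum8(a, b));
    }

    int main() { check<<<1, 1>>>(); return cudaDeviceSynchronize() != cudaSuccess; }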
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h

@@ -18,7 +18,7 @@ inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const ui
                                 Smem_dp_sum smem, const int buffer_idx) {
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+        sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi]));
     }
     static_assert(M == 1);
     smem.store(sum[0], buffer_idx);
@@ -358,7 +358,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.pack(frag_p);
+    softmax.template pack<__half>(frag_p);
     // Store s * dmask to smem for transpose
     smem_s.store(frag_p);
@@ -463,7 +463,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     if (is_first_read) { softmax.subtract_dp_sum(dp_sum); }
 
     Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
-    softmax.pack(frag_dp);
+    softmax.template pack<__half>(frag_dp);
 
     if (!Is_dropout) {
         #pragma unroll
@@ -544,7 +544,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
             #pragma unroll
             for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
-                frag_s[ki][mi].hrelu_();
+                frag_s[ki][mi].template hrelu_<__half>();
             }
         }
     }
@@ -638,7 +638,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // }
             dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
             // Output the values.
-            gmem_dq.store(dq_out, 0);
+            gmem_dq.template store<__half>(dq_out, 0);
         } else {
             // Output the values.
             gmem_dq_tmp.store(dq_out, 0);
@@ -693,11 +693,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // the total amount of shared mem?
     // Epilogue swizzle for dV
     Smem_tile_dv smem_dv(&smem_[0], tidx);
-    smem_dv.store(acc_dv);
+    smem_dv.template store<__half>(acc_dv);
 
     // Epilogue swizzle for dK
     Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
-    smem_dk.store(acc_dk);
+    smem_dk.template store<__half>(acc_dk);
 
     __syncthreads();
 
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
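
dot_do_o computes, for each row, the dot product of dO and O needed for the softmax gradient: every thread reduces its own 8-element chunk with hmulsum8, and the per-thread partials are then combined across the row (Smem_dp_sum::reduce_warp here, fmha::Allreduce in the non-block variant; both live elsewhere in the repo). A generic shuffle-based warp all-reduce of the kind such helpers rely on looks roughly like this (hedged sketch, not the repo's implementation):

    // Butterfly all-reduce over THREADS_PER_ROW consecutive lanes of a warp.
    // Assumes THREADS_PER_ROW is a power of two no larger than the warp size.
    template<int THREADS_PER_ROW>
    __device__ float allreduce_sum(float x) {
        #pragma unroll
        for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
            x += __shfl_xor_sync(0xffffffffu, x, offset);
        }
        return x;   // every participating lane ends up holding the row sum
    }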
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h

@@ -335,7 +335,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
     static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.pack(frag_p);
+    softmax.template pack<__half>(frag_p);
     if (Return_softmax) {
         gmem_s.store(frag_p, mask);
         if (not_last_iter) {
@@ -353,7 +353,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
             #pragma unroll
             for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
-                frag_p[ki][mi].hrelu_();
+                frag_p[ki][mi].template hrelu_<__half>();
             }
         }
     }
@@ -371,7 +371,6 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     // The mapping from tidx to rows changes between the softmax and the O-reduction.
     // So we recalculate the max.
     float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
-    // TODO: not sure if this is right for seqlen 128 or 256
     int rows[Gmem_tile_o::STGS_PER_LOOP];
     for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
         rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@@ -467,7 +466,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         // Output the values.
         if (is_final_write) {
-            gmem_o.store(out, 0);
+            gmem_o.template store<__half>(out, 0);
         } else {
             gmem_o_tmp.store(out, 0);
         }
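
Besides the explicit __half calls, this hunk drops a stale TODO next to the thread-to-row mapping used when writing O back out: on its jj-th store, thread tidx covers row tidx / THREADS_PER_ROW + jj * ROWS_PER_STG. A tiny host-side sketch of that mapping with made-up tile constants (the real values come from Gmem_tile_o's traits):

    #include <cstdio>

    // Illustrative constants only; not the actual Gmem_tile_o configuration.
    constexpr int THREADS_PER_ROW = 8;    // threads cooperating on one output row
    constexpr int ROWS_PER_STG    = 16;   // row stride between successive stores
    constexpr int STGS_PER_LOOP   = 4;    // stores issued per loop iteration

    int main() {
        for (int tidx = 0; tidx < 32; tidx += 8) {
            printf("tidx %2d writes rows:", tidx);
            for (int jj = 0; jj < STGS_PER_LOOP; jj++) {
                printf(" %3d", tidx / THREADS_PER_ROW + jj * ROWS_PER_STG);
            }
            printf("\n");
        }
        return 0;
    }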
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -12,14 +12,16 @@ namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
+template<int ROWS, int THREADS_PER_ROW, typename elem_type=__half, int M, typename Gmem_softmax_sum>
 inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
                                 Gmem_softmax_sum gmem_softmax_d, int tidx) {
     float sum[M];
     fmha::SumOp<float> sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale;
+        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8<elem_type>(do_[mi], o[mi]), sum_op) * scale;
     }
     const int dp_sum_row = tidx / THREADS_PER_ROW;
     if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
@@ -212,7 +214,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
            gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
        );
    }
@@ -331,7 +333,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.pack(frag_p);
+    softmax.template pack<__half>(frag_p);
     // Store s * dmask to smem for transpose
     smem_s.store(frag_p);
@@ -422,7 +424,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
 
-    softmax.pack(frag_p);
+    softmax.template pack<__half>(frag_p);
     // Store dp to smem for transpose
     smem_dp.store(frag_p);
@@ -473,7 +475,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
             #pragma unroll
             for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
-                frag_s[ki][mi].hrelu_();
+                frag_s[ki][mi].template hrelu_<__half>();
             }
         }
     }
@@ -517,7 +519,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     if (l < steps - 1) {
         gmem_do.commit(smem_do);
         if (Is_first) {
-            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+            dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
                gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
            );
        }
@@ -573,7 +575,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
            dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
        }
        // Output the values.
-        gmem_dq.store(dq_out, 0);
+        gmem_dq.template store<__half>(dq_out, 0);
        // Move to the next part of the output.
        gmem_dq.move();
    } else {
@@ -627,11 +629,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // the total amount of shared mem?
     // Epilogue swizzle for dV
     Smem_tile_dv smem_dv(&smem_[0], tidx);
-    smem_dv.store(acc_dv);
+    smem_dv.template store<__half>(acc_dv);
 
     // Epilogue swizzle for dK
     Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
-    smem_dk.store(acc_dk);
+    smem_dk.template store<__half>(acc_dk);
 
     __syncthreads();
 
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
@@ -644,9 +646,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
-    // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
-    //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
-    // }
 
     Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
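
Note where elem_type=__half sits in dot_do_o's parameter list: in a function template a defaulted parameter may be followed by non-defaulted ones, as long as those (M and Gmem_softmax_sum here) are deduced from the arguments. That keeps both the old two-argument spelling and the new <..., __half> spelling valid at the call sites above. A compact host-only illustration with hypothetical names:

    #include <cstdio>

    // Defaulted elem_type may precede M and T because both are deduced from the
    // function arguments -- the same shape as the new dot_do_o signature.
    template<int ROWS, typename elem_type = float, int M, typename T>
    void dot_sketch(const T (&do_)[M], const T (&o)[M]) {
        printf("ROWS=%d M=%d elem bytes=%zu\n", ROWS, M, sizeof(elem_type));
    }

    int main() {
        int do_[4] = {0}, o[4] = {0};
        dot_sketch<16>(do_, o);          // old spelling: elem_type falls back to the default
        dot_sketch<16, double>(do_, o);  // new spelling: the element type is named explicitly
        return 0;
    }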
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -466,7 +466,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
     static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.pack(frag_p);
+    softmax.template pack<__half>(frag_p);
     if (Return_softmax) {
         gmem_s.store(frag_p, mask);
         gmem_s.move();
@@ -482,7 +482,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
             #pragma unroll
             for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
-                frag_p[ki][mi].hrelu_();
+                frag_p[ki][mi].template hrelu_<__half>();
             }
         }
     }
@@ -509,7 +509,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // The mapping from tidx to rows changes between the softmax and the
     // O-reduction. So we recalculate the max.
     float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
-    // TODO: not sure if this is right for seqlen 128 or 256
     int rows[Gmem_tile_o::STGS_PER_LOOP];
     for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
         rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@@ -606,7 +605,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Output the values.
         if (is_final_write) {
-            gmem_o.store(out, 0);
+            gmem_o.template store<__half>(out, 0);
             gmem_o.move();
         } else {
             gmem_o_tmp.store(out, 0);
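
Every call site in the forward kernel now names __half explicitly (pack<__half>, hrelu_<__half>, store<__half>), which is the hook for the obvious follow-up: make the element type a parameter of the kernel itself and forward it instead of the literal __half. A hypothetical sketch of that forwarding, not code from this commit (SoftmaxSketch and GmemOSketch stand in for the real tile objects; needs CUDA 11+ for cuda_bf16.h):

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Minimal stand-ins for the softmax and output-tile objects used by device_1xN_.
    struct SoftmaxSketch {
        template<typename elem_type, typename Frag>
        __device__ void pack(Frag &) const { /* would pack via float2_pack<elem_type> */ }
    };
    struct GmemOSketch {
        template<typename elem_type = __half, typename Out>
        __device__ void store(const Out &, int) { /* would pack via float4_pack<elem_type> */ }
    };

    // Once elem_type is threaded through the kernel, the literal __half at each
    // call site becomes the forwarded parameter.
    template<typename elem_type>
    __device__ void device_1xN_sketch(SoftmaxSketch &softmax, GmemOSketch &gmem_o) {
        int frag_p = 0;                                // placeholder fragment
        softmax.template pack<elem_type>(frag_p);
        int out = 0;                                   // placeholder output tile
        gmem_o.template store<elem_type>(out, 0);
    }

    __global__ void fp16_kernel(SoftmaxSketch s, GmemOSketch o) { device_1xN_sketch<__half>(s, o); }
    __global__ void bf16_kernel(SoftmaxSketch s, GmemOSketch o) { device_1xN_sketch<__nv_bfloat16>(s, o); }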