Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
14dc326e
Commit
14dc326e
authored
Jun 02, 2022
by
Tri Dao
Browse files
Use Cutlass gemm as WarpMma
parent
e78e7c95
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
30 deletions
+80
-30
csrc/flash_attn/src/fmha/gemm.h
csrc/flash_attn/src/fmha/gemm.h
+50
-0
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
+12
-12
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
+1
-1
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+12
-12
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+5
-5
No files found.
csrc/flash_attn/src/fmha/gemm.h
View file @
14dc326e
...
...
@@ -29,6 +29,13 @@
#include <fmha/utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -247,6 +254,49 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level GEMM that accumulates acc += a * b using a CUTLASS tensor-op MMA
// (m16n8k16, fp16 inputs, fp32 accumulator) instead of the hand-rolled fmha::gemm.
//
// acc: [M][N] array of accumulator fragments (fp32), updated in place.
// a:   [M]    array of row-major fp16 A fragments.
// b:   [N]    array of column-major fp16 B fragments.
//
// The fmha fragment arrays are reinterpreted as the corresponding CUTLASS warp
// fragments; the static_asserts below check that the register footprints match.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
    using Element = cutlass::half_t;
    using ElementC = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    using FragmentA = typename WarpMma::FragmentA;
    using FragmentB = typename WarpMma::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;

    // Verify the CUTLASS fragments cover exactly the same registers as the
    // fmha fragments they alias.
    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);

    // Bind the fragments BY REFERENCE. The previous version copied them into
    // local fragment values, so the MMA wrote its result into a register copy
    // of acc (hence the old "modified c_cl is not copied back into acc"
    // comment and the manual element-by-element copy-back loop, which also
    // hard-coded 8 elements per accumulator). Aliasing acc directly makes the
    // MMA accumulate in place and removes the copy-back entirely.
    const FragmentA &a_cl = reinterpret_cast<const FragmentA &>(a);
    const FragmentB &b_cl = reinterpret_cast<const FragmentB &>(b);
    FragmentC &c_cl = reinterpret_cast<FragmentC &>(acc);

    WarpMma mma_op;
    // D = A * B + C with D and C both aliasing acc, i.e. acc += a * b.
    mma_op(c_cl, a_cl, b_cl, c_cl);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The number of rows in the CTA tile.
int
M_
,
...
...
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
View file @
14dc326e
...
...
@@ -408,9 +408,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
...
...
@@ -424,9 +424,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
}
}
...
...
@@ -515,14 +515,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha
::
gemm
_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm
_cl
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dq
::
MMAS_K
;
fmha
::
gemm
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha
::
gemm
_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm
_cl
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
...
...
@@ -555,13 +555,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
// __syncthreads();
...
...
@@ -613,13 +613,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Make sure dQ is in shared memory.
...
...
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
View file @
14dc326e
...
...
@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
// Do this part of O = P^T * V^T.
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
fmha
::
gemm
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
fmha
::
gemm
_cl
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
...
...
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
View file @
14dc326e
...
...
@@ -383,9 +383,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
...
...
@@ -399,9 +399,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
fmha
::
gemm
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
fmha
::
gemm
_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
}
}
...
...
@@ -484,14 +484,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha
::
gemm
_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm
_cl
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dq
::
MMAS_K
;
fmha
::
gemm
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
fmha
::
gemm
_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm
_cl
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
...
...
@@ -524,13 +524,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
// __syncthreads();
...
...
@@ -579,13 +579,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
// Make sure dQ is in shared memory.
...
...
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
View file @
14dc326e
...
...
@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
// Trigger the load from shared memory for the next series of Q values.
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
fmha
::
gemm
_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
fmha
::
gemm
_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
}
...
...
@@ -175,12 +175,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
Base
::
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
fmha
::
gemm
_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
}
...
...
@@ -497,7 +497,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
// Do this part of O = P^T * V^T.
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
fmha
::
gemm
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
fmha
::
gemm
_cl
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment