gaoqiong / flash-attention · Commits

Commit e78e7c95, authored Jun 02, 2022 by Tri Dao
Remove old backward
parent 512c98ee

Showing 2 changed files with 0 additions and 938 deletions:
    csrc/flash_attn/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu      +0  -83
    csrc/flash_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h   +0  -855

csrc/flash_attn/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu  (deleted, 100644 → 0)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
// #include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_recompute.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

// extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dv_kernel(Fused_multihead_attention_fprop_params params) {
//     fmha::compute_dv_1xN<Kernel_traits>(params);
// }

// extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dq_dk_kernel(Fused_multihead_attention_fprop_params params) {
//     fmha::compute_dq_dk_1xN<Kernel_traits>(params);
// }

extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dp_dq_1xN<Kernel_traits>(params);
}

extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_dk_1xN<Kernel_traits>(params);
}

void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * 512 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    // constexpr int smem_size_dp_dq = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    // constexpr int smem_size_dv_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
    constexpr int smem_size_dp_dq = smem_size_q * 2 + smem_size_q + smem_size_v + smem_size_o;
    constexpr int smem_size_dv_dk = smem_size_q + smem_size_q + smem_size_v + smem_size_o + smem_size_s;

    if( smem_size_dp_dq >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            // fmha_dgrad_fp16_512_64_sm80_dv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dp_dq));
            fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dp_dq));
    }
    if( smem_size_dv_dk >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dv_dk));
    }

    dim3 grid(params.h, params.b);
    // fmha_dgrad_fp16_512_64_sm80_dv_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
    // fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
    fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
    fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel<<<grid, Kernel_traits::THREADS, smem_size_dv_dk, stream>>>(params);

    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
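Note: the launcher above opts each kernel into more than 48 KB of dynamic shared memory via cudaFuncSetAttribute before launching it. Below is a minimal, self-contained sketch of that same pattern, using a hypothetical demo_kernel and sizes that are not part of this repository; it only illustrates the opt-in-then-launch sequence.

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel that uses dynamically sized shared memory.
__global__ void demo_kernel(float *out) {
    extern __shared__ char smem_[];
    float *buf = reinterpret_cast<float *>(smem_);
    buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    constexpr int threads = 256;
    // Request more than the 48 KB default limit, so the attribute opt-in is required.
    constexpr int smem_bytes = 64 * 1024;

    float *out;
    cudaMalloc(&out, threads * sizeof(float));

    // Same pattern as run_fmha_dgrad_fp16_512_64_sm80: raise the limit, then pass the size at launch.
    cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    demo_kernel<<<1, threads, smem_bytes, 0>>>(out);

    printf("launch status: %s\n", cudaGetErrorString(cudaPeekAtLastError()));
    cudaFree(out);
    return 0;
}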
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h  (deleted, 100644 → 0)
/* Copyright (c) 2022, Tri Dao.
*/
#pragma once
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int MMAS_M>
inline __device__ void dot_fragments(float (&sum)[MMAS_M * 2],
                                     const fmha::Fragment_a<fmha::Row> (&x)[MMAS_M],
                                     const fmha::Fragment_a<fmha::Row> (&y)[MMAS_M]) {
    #pragma unroll
    for( int mi = 0; mi < MMAS_M; ++mi ) {
        sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(0), y[mi].template elt_as<__half2>(0));
        sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(2), y[mi].template elt_as<__half2>(2));
        sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(1), y[mi].template elt_as<__half2>(1));
        sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(3), y[mi].template elt_as<__half2>(3));
        // hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(0), y[mi].template elt_as<__half2>(0));
        // hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(2), y[mi].template elt_as<__half2>(2));
        // hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(1), y[mi].template elt_as<__half2>(1));
        // hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(3), y[mi].template elt_as<__half2>(3));
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dp_dq_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K^T. Treat K^T as V.
    using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;

    // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong.
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
    // The shared memory tile to load dO.
    // Treating dO as Q.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load O. Loading O here is similar to loading dO.
    using Gmem_tile_o = Gmem_tile_do;
    // The shared memory tile to load O.
    using Smem_tile_o = Smem_tile_do;

    // The global memory tile to store dQ.
    // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
    using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
    // The shared memory tile to swizzle dQ.
    using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for dQ.
    Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p> mask(binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
    Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for dO.
    Smem_tile_do smem_do(&smem_[0], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for O.
    Smem_tile_o smem_o(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for O.
    gmem_o.load();

    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t &>(params.scale_bmm1);
    #pragma unroll
    for( int it = 0; it < Gmem_tile_k::LDGS; it++ ){
        gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    }

    // Commit the data for Q, dO, and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_do.commit(smem_do);
    gmem_o.commit(smem_o);
    gmem_v.commit(smem_v);

    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for dO.
    typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
    smem_do.load(frag_do[0], 0);

    // Load the fragments for O.
    typename Smem_tile_o::Fragment frag_o[2][Mma_tile_p::MMAS_M];
    smem_o.load(frag_o[0], 0);

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
    smem_kt.load(frag_kt[0], 0);
    // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
    //     smem_kt.load(frag_kt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for this softmax at all
    // Softmax softmax(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);
    Softmax softmax(params, smem_, tidx);
    // Softmax softmax_dp(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);

    Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;
    // Load over the entire sequence length.
    for( int l = 0; l < STEPS; l++ ) {
        const int loop = l * Cta_tile_p::M;
        if( loop >= binfo.actual_seqlen ) break;

        float p_lse[Mma_tile_p::MMAS_M * 2];
        gmem_softmax_sum.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
        gmem_softmax_sum.move();

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);

        // Trigger the load for the next Q values.
        if( l < STEPS - 1 ) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load();
        }

        // Load the mask for that iteration.
        mask.load(l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        // Scale by log-sum-exp of the softmax
        softmax.template apply_exp</*max_in_base2=*/true>(p_lse);

        // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.pack(frag_p);

        // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
        //     // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
        //     __syncthreads();
        // }

        float dp_sum_new[Mma_tile_p::MMAS_M * 2] = {0};
        dot_fragments(dp_sum_new, frag_do[0], frag_o[0]);

        fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);

        // Do this part of dP^T = (dO * V^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of dO values.
            smem_do.load(frag_do[ki & 1], ki);
            smem_o.load(frag_o[ki & 1], ki);
            dot_fragments(dp_sum_new, frag_do[ki & 1], frag_o[ki & 1]);
            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     printf("dp_sum_new=%.6f, %.6f\n", dp_sum_new[0], dp_sum_new[1]);
            // }
            // smem_v.load(frag_v[ki & 1], ki);
            // Do the math for the values already in registers.
            // fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            // if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
            //     printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
            //     tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki - 1]));
            //     printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
            // }
            fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            // fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
        }

        // if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //     printf("acc_dp=%.6f, %.6f\n", acc_dp[0][0].elt(0), acc_dp[0][0].elt(1));
        // }

        // Trigger the load for the next dO values.
        if( l < STEPS - 1 ) {
            smem_do.move_to_next_write_buffer();
            gmem_do.move();
            gmem_do.load();
            smem_o.move_to_next_write_buffer();
            gmem_o.move();
            gmem_o.load();
        }

        // softmax_dp.unpack_noscale(acc_dp);
        // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
        // // will be zero.
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         softmax_dp.elt_[mi][ni] *= softmax.elt_[mi][ni];
        //     }
        // }
        // float dp_sum[Mma_tile_p::MMAS_M * 2];
        // softmax_dp.reduce_sum(dp_sum);
        // gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
        // gmem_softmax_d.move();
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         softmax_dp.elt_[mi][ni] -= dp_sum[mi] * softmax.elt_[mi][ni];
        //     }
        // }

        fmha::SumOp<float> sum_op;
        fmha::quad_allreduce(dp_sum_new, dp_sum_new, sum_op);

        // softmax_dp.unpack_noscale(acc_dp);
        softmax.unpack_noscale(acc_dp);
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         // softmax_dp.elt_[mi][ni] -= dp_sum_new[mi];
        //         softmax.elt_[mi][ni] -= dp_sum_new[mi];
        //     }
        // }
        softmax.subtract_dp_sum(dp_sum_new);

        Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        // softmax_dp.pack(frag_dp);
        softmax.pack(frag_dp);

        gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum_new));
        gmem_softmax_d.move();

        #pragma unroll
        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
            }
        }

        // softmax_dp.pack(frag_p);
        // gmem_s.store(frag_p, mask);
        // gmem_s.move();

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if( l < STEPS - 1 ) {
            gmem_q.commit(gemm_q_k.smem_q);
            gmem_do.commit(smem_do);
            gmem_o.commit(smem_o);
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_kt.load(frag_kt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dq::MMAS_K;
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_dq::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_dq.store(acc_dq, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_dq::STGS_PER_LOOP];
            smem_dq.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_dq::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_dq.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_dq.move();
        gemm_q_k.reload_k();
        smem_kt.load(frag_kt[0], 0);

        // // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
        if( l < STEPS - 1 ) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_do.move_to_next_read_buffer();
            smem_do.load(frag_do[0], 0);
            smem_o.move_to_next_read_buffer();
            smem_o.load(frag_o[0], 0);
        }

    }  // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_dk_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dkv = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dkv>;

    // The global memory tile to load Q. Treating Q as K.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_k;

    // The global memory tile to load K. Treating K as Q.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q^T. Treat Q^T as V.
    using Smem_tile_qt = typename Kernel_traits::Smem_tile_v;

    // Treating V as dO.
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_q;

    // Treating dO as V in dQ kernel, which is the same as K in the forward kernel.
    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_dot;
    // The shared memory tile to load dO.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_k;
    // The shared memory tile to load dO^T.
    using Smem_tile_dot = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store dK and dV.
    // using Gmem_tile_dkv = typename Kernel_traits::Gmem_tile_dkv;
    using Gmem_tile_dkv = fmha::Gmem_tile_dq<Cta_tile_dkv>;
    // The shared memory tile to swizzle dK and dV.
    using Smem_tile_dkv = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() ) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for dK.
    Gmem_tile_dkv gmem_dk(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for dV.
    Gmem_tile_dkv gmem_dv(params, 2, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p> mask(binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // Allocate the shared memory tile loader for dO.
    Smem_tile_v smem_v(&smem_[0], tidx);

    // Allocate the shared memory tile loader for Q^T. We use the same as Q so be careful!!!
    Smem_tile_qt smem_qt(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
    // The base pointer of smem_do;
    char *smem_do_ = &smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_do smem_do(smem_do_, tidx);
    Smem_tile_dot smem_dot(smem_do_, tidx);

    // Allocate the shared memory tile loader for dK and dV. We use the same as K so be careful!!!
    Smem_tile_dkv smem_dkv(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for V.
    gmem_v.load();

    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t &>(params.scale_bmm1);
    #pragma unroll
    for( int it = 0; it < Gmem_tile_q::LDGS; it++ ){
        gmem_q.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_q.fetch_[it]);
    }

    // Commit the data for K, dO, and V to shared memory.
    gmem_k.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);
    gmem_do.commit(smem_do);

    // Commit the data for Q to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_q.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for K.
    gemm_q_k.load_q();

    // Load the fragments for V.
    typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_M];
    smem_v.load(frag_v[0], 0);

    // Load the fragments for dO. We keep the data in registers during the entire kernel.
    typename Smem_tile_do::Fragment frag_do[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
        smem_do.load(frag_do[ki], ki);
    }

    using Smem_tile_mma_t = fmha::Smem_tile_transpose<Cta_tile_p>;
    // Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
    Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dkv::BYTES_PER_TILE], tidx);

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_q.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for Q.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
    smem_qt.load(frag_qt[0], 0);
    // typename Smem_tile_qt::Fragment frag_qt[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
    //     smem_qt.load(frag_qt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for either of the softmax at all
    Softmax softmax(params, smem_, tidx);
    Softmax softmax_dp(params, smem_, tidx);

    Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
    int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
    int rows[Mma_tile_p::MMAS_N * 4];
    for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
        rows[ni * 4 + 0] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2;
        rows[ni * 4 + 1] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 1;
        rows[ni * 4 + 2] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 8;
        rows[ni * 4 + 3] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 9;
    }
    float p_lse[Mma_tile_p::MMAS_N * 4];
    gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(p_lse), rows);
    float dp_sum[Mma_tile_p::MMAS_N * 4];
    gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(dp_sum), rows);

    // int qid = lane / 8;
    // int rows_shfl[Mma_tile_p::MMAS_N];
    // for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
    //     rows_shfl[ni] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + (qid / 2) * 8 + (qid % 2);
    // }
    // float p_lse[Mma_tile_p::MMAS_N];
    // gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(p_lse), rows_shfl);
    // float dp_sum[Mma_tile_p::MMAS_N];
    // gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(dp_sum), rows_shfl);

    constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;
    // Load over the entire sequence length.
    for( int l = 0; l < STEPS; l++ ) {
        const int loop = l * Cta_tile_p::M;
        if( loop >= binfo.actual_seqlen ) break;

        typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_p::MMAS_N];
        // smem_mmat.store(frag_do, 0);
        // smem_mmat.load(frag_dot[0]);
        // smem_mmat.transpose(frag_do, frag_dot[0], 0);

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_pt[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_pt);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_pt);

        // Trigger the load for the next K values.
        if( l < STEPS - 1 ) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_k.move();
            gmem_k.load();
        }

        // Load the mask for that iteration.
        mask.load(l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_pt);

        // Apply the mask.
        softmax.apply_mask(mask);

        // Scale by log-sum-exp of the softmax
        softmax.template apply_exp_col</*max_in_base2=*/true>(p_lse);

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
        softmax.pack(frag_p);

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_dv[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);

        smem_mmat.transpose(frag_do, frag_dot[0], 0);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
            // fmha::gemm(acc_dv, frag_p[ki], frag_dot[ki]);
            if (ki + 1 < Mma_tile_dk::MMAS_K) {
                // smem_mmat.store(frag_do, ki + 1);
                // smem_mmat.load(frag_dot[(ki + 1) % 2]);
                smem_mmat.transpose(frag_do, frag_dot[(ki + 1) % 2], ki + 1);
            }
            fmha::gemm(acc_dv, frag_p[ki], frag_dot[ki % 2]);
        }

        __syncthreads();

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_dkv::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_dkv.store(acc_dv, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_dkv::STGS_PER_LOOP];
            smem_dkv.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_dkv::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_dv.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_dv.move();

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        fmha::Fragment_accumulator acc_dpt[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dpt);

        // Do this part of dP^T = (dO * V^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of dO values.
            smem_v.load(frag_v[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[ki - 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            // fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[(ki - 1) & 1]);
            fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[(ki - 1)]);
        }

        // Trigger the load for the next V values.
        if( l < STEPS - 1 ) {
            smem_v.move_to_next_write_buffer();
            gmem_v.move();
            gmem_v.load();
        }

        softmax_dp.unpack_noscale(acc_dpt);
        // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
        // will be zero.
        #pragma unroll
        for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
                // softmax.elt_[mi][ni] *= (softmax_dp.elt_[mi][ni] - dp_sum[ni]);
                softmax_dp.elt_[mi][ni] -= dp_sum[ni];
                // const float tmp = __shfl_sync(0xffffffff, dp_sum[ni / 4], (ni % 4) * 8 + threadIdx.x % 8);
                // softmax_dp.elt_[mi][ni] -= tmp;
            }
        }

        Frag_p frag_dp[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
        softmax_dp.pack(frag_dp);

        #pragma unroll
        for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
            }
        }

        // using Frag_p = fmha::Fragment_a<fmha::Row>;
        // Frag_p frag_p[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
        // softmax.pack(frag_p);
        // softmax_dp.pack(frag_p);
        // gmem_s.store(frag_p, mask);
        // gmem_s.move();

        __syncthreads();
        // Commit the values for K and V into shared memory.
        if( l < STEPS - 1 ) {
            gmem_k.commit(gemm_q_k.smem_q);
            gmem_v.commit(smem_v);
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1) & 1]);
            // fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1) & 1]);
            // fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_dkv::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_dkv.store(acc_dk, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_dkv::STGS_PER_LOOP];
            smem_dkv.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_dkv::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_dk.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_dk.move();
        gemm_q_k.reload_k();
        smem_qt.load(frag_qt[0], 0);

        // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
        if( l < STEPS - 1 ) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_v.move_to_next_read_buffer();
            smem_v.load(frag_v[0], 0);
        }

    }  // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}  // namespace fmha
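Note: the dp_dq kernel above recomputes the attention probabilities P from Q*K^T and the saved log-sum-exp, and relies on the identity rowsum(P ∘ dP) = rowsum(dO ∘ O) (computed by dot_fragments over dO and O) to form dS = P ∘ (dP - dp_sum) without materializing the full attention matrix. Below is a minimal scalar reference of that softmax-backward step for a single row; softmax_bwd_row is a hypothetical helper written only for illustration, with no tiling or half2 packing.

#include <cmath>
#include <vector>

// Reference for one row: given scores s = q·K^T (already scaled), the saved
// log-sum-exp lse, the upstream row gradients dp = dO·V^T, and dp_sum = dO·o,
// produce ds = p * (dp - dp_sum), where p = exp(s - lse) is the softmax row.
// This mirrors apply_exp(p_lse), subtract_dp_sum, and frag_p.hmul(frag_dp) above.
std::vector<float> softmax_bwd_row(const std::vector<float> &s,
                                   float lse,
                                   const std::vector<float> &dp,
                                   float dp_sum) {
    std::vector<float> ds(s.size());
    for (size_t j = 0; j < s.size(); ++j) {
        const float p = std::exp(s[j] - lse);   // recomputed softmax probability
        ds[j] = p * (dp[j] - dp_sum);           // dS = P ∘ (dP - rowsum(dO ∘ O))
    }
    return ds;
}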