Unverified commit ac3ef99c authored by Dan Yao, committed by GitHub

Merge pull request #1010 from ROCmSoftwarePlatform/mha-train-develop-bias-shfl

Add bias with shuffle for flash attention fwd
parents acea1753 e87ddb0e
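
The diff below replaces the old per-thread global loads of the D0 bias (the 10-dimensional M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 descriptor path) with a two-stage copy: a block-wise copy first stages a bias tile into LDS, and a thread-wise copy then shuffles it into registers before it is added to the attention accumulator. As a rough mental model only, the following stand-alone C++ sketch mimics that data movement on the host; the tile sizes, buffer names, and flat indexing are illustrative assumptions and not the kernel's actual template parameters or copy operators.

    // Host-side model of the new bias path: bias tile -> staging buffer ("LDS")
    // -> per-element accumulate. Sizes below are illustrative, not the kernel's.
    #include <array>
    #include <cstdio>
    #include <vector>

    constexpr int MPerBlock = 4; // rows handled per block (assumed for illustration)
    constexpr int NPerBlock = 8; // bias columns handled per outer step (assumed)

    int main()
    {
        const int M = 8, N = 16;
        std::vector<float> d0(M * N);          // D0 bias matrix in "global" memory
        std::vector<float> acc(M * N, 0.0f);   // attention accumulator
        for(int i = 0; i < M * N; ++i)
            d0[i] = 0.01f * i;

        std::array<float, MPerBlock * NPerBlock> lds_tile{}; // stand-in for LDS staging

        for(int mb = 0; mb < M; mb += MPerBlock)
            for(int nb = 0; nb < N; nb += NPerBlock)
            {
                // block-wise copy: move one bias tile from global into the staging buffer
                for(int m = 0; m < MPerBlock; ++m)
                    for(int n = 0; n < NPerBlock; ++n)
                        lds_tile[m * NPerBlock + n] = d0[(mb + m) * N + (nb + n)];

                // thread-wise copy + accumulate: add the staged tile onto the accumulator
                for(int m = 0; m < MPerBlock; ++m)
                    for(int n = 0; n < NPerBlock; ++n)
                        acc[(mb + m) * N + (nb + n)] += lds_tile[m * NPerBlock + n];
            }

        std::printf("acc[0]=%f acc[last]=%f\n", acc[0], acc[M * N - 1]);
        return 0;
    }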
......@@ -101,7 +101,7 @@ using DeviceGemmInstance =
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
......@@ -121,13 +121,13 @@ using DeviceGemmInstance =
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
8,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
DIM / 32,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
......
......@@ -120,7 +120,7 @@ using DeviceGemmInstance =
8,
8,
true,
1,
8,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
......
......@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool time_kernel = true;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
......
......@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool time_kernel = true;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
......
......@@ -35,7 +35,7 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename D0GridDescriptor_M0_N0_N1_N2_M1_N3,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
......@@ -60,8 +60,7 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t h_ratio,
......@@ -109,7 +108,7 @@ __global__ void
c1de_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
b1_grid_desc_bk0_n_bk1,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map,
......@@ -129,7 +128,7 @@ __global__ void
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = h_ratio;
......@@ -206,7 +205,6 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
int D0sTransferSrcScalarPerVector = 4,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionInfer<NumDimG,
......@@ -537,9 +535,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
{
D0GridDesc_M_N d0_grid_desc_m_n_ = MakeD0GridDescriptor_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n_);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n_);
d0_grid_desc_g_m_n_ = MakeD0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
......@@ -592,8 +589,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -660,7 +656,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
......@@ -685,7 +681,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.h_ratio_,
......
......@@ -112,7 +112,7 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].block_2_ctile_map_,
......@@ -465,8 +465,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -576,9 +575,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
d0_grid_desc_m_n);
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
......@@ -628,7 +626,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
p_c_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_grid_desc_m0_n0_m1_m2_n1_m3_,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......
......@@ -92,11 +92,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
"D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -366,6 +361,125 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
static constexpr auto D0N2 = AK1;
static constexpr auto D0N1 = AK0;
static constexpr auto D0N0 = Number<NPerBlock / KPerBlock>{};
static_assert(NPerBlock % KPerBlock == 0);
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(const D0GridDesc_M_N& d0_grid_desc_m_n)
{
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto d0_grid_desc_m0_n0_n1_n2_m1_n3 = transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, D0N0, D0N1, D0N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3, 5>{}));
return d0_grid_desc_m0_n0_n1_n2_m1_n3;
}
using D0GridDescriptor_M0_N0_N1_N2_M1_N3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(D0GridDesc_M_N{}))>;
struct D0Operator
{
static_assert(ABlockTransferThreadClusterLengths_AK0_M_AK1::Size() == 3);
static_assert(ABlockTransferDstScalarPerVector_AK1 % D0BlockTransferSrcScalarPerVector ==
0);
template <typename DataType>
struct TypeTransform
{
using Type = DataType;
};
template <>
struct TypeTransform<void>
{
using Type = ck::half_t;
};
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3()
{
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0N1, Number<MPerBlock>{}, D0N2));
}
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2()
{
constexpr auto d0_raw_n0_m_n1 =
make_naive_tensor_descriptor_packed(make_tuple(D0N1, Number<MPerBlock>{}, D0N2));
constexpr auto d0_raw_m_n = transform_tensor_descriptor(
d0_raw_n0_m_n1,
make_tuple(make_pass_through_transform(Number<MPerBlock>{}),
make_merge_transform(make_tuple(D0N1, D0N2))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto d0_m0_m1_n0_n1_n2 = transform_tensor_descriptor(
d0_raw_m_n,
make_tuple(make_unmerge_transform(
make_tuple(Number<MPerBlock / MPerXdl>{}, Number<MPerXdl>{})),
make_unmerge_transform(make_tuple((D0N1 * D0N2) / (I2 * I4), I2, I4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}));
return d0_m0_m1_n0_n1_n2;
}
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, I4));
static constexpr auto d0_block_dst_desc_m0_n0_n1_n2_m1_n3 =
GetD0BlockGlobalDescriptor_M0_N0_N1_N2_M1_N3();
static constexpr auto d0_block_src_desc_m0_m1_n0_n1_n2 =
GetD0BlockVgprDescriptor_M0_M1_N0_N1_N2();
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0N1, MPerBlock, D0N2>,
typename sequence_merge<Sequence<1, 1, 1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1>::type,
Sequence<0, 1, 2, 4, 3, 5>,
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_N1_N2_M1_N3,
decltype(d0_block_dst_desc_m0_n0_n1_n2_m1_n3),
Sequence<0, 1, 2, 4, 3, 5>,
Sequence<0, 1, 2, 4, 3, 5>,
5,
5,
D0BlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_src_desc_m0_m1_n0_n1_n2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
4, // SrcScalarPerVector
2>;
};
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
......@@ -403,7 +517,8 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
};
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const D0DataType* __restrict__ p_d0_grid,
const FloatAB* __restrict__ p_b1_grid,
......@@ -416,8 +531,7 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const D0GridDescriptor_M0_N0_N1_N2_M1_N3& d0_grid_desc_m0_n0_n1_n2_m1_n3,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -683,49 +797,6 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// bias (d0 matrix)
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // RegisterNum
auto d0_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<D0DataType,
D0DataType,
decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9,
D0BlockTransferSrcScalarPerVector,
1,
false>(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0)); // register number
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
......@@ -827,6 +898,17 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
// gemm1 K loop
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_n1_n2_m1_n3,
make_multi_index(block_work_idx[I0], 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I0], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
index_t gemm1_k_block_outer_index = 0;
do
{
......@@ -920,30 +1002,53 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// get register
StaticBuffer<AddressSpaceEnum::Vgpr,
D0DataType,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
d0_thread_buf;
// load data from global
d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
d0_grid_buf,
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
p_d0_grid, d0_grid_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0N0, 1>{}([&](auto nr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_n1_n2_m1_n3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_n1_n2_m1_n3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_dst_desc_m0_n0_n1_n2_m1_n3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_m0_m1_n0_n1_n2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// acc add bias
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
acc_thread_buf(i) += ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(I0, nr, i));
acc_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
d0_threadwise_copy.MoveSrcSliceWindow(
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_n1_n2_m1_n3,
make_multi_index(0, 1, -D0N0.value, 0, 0, 0));
}
}
// softmax
......