Unverified commit c961ce92, authored by Anthony Chang, committed by GitHub

Hotfix LDS data hazard in fused attention (#360)

* avoid LDS data hazard in gemm_softmax_gemm pipeline

* trivial refactors

* comments

* shrink blockwise gemm v2 thread buffer size

* reclaim A block LDS space during 2nd gemm

* amend

* amend
parent 53ea4713
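The bug class being fixed is worth spelling out before the hunks. Gemm0's B tile, gemm1's B1 tile, and the softmax reduction workspace all overlay the same LDS bytes; without a barrier between the last read of one stage and the first write of the next, a fast wave can clobber data a slow wave is still reading. A minimal HIP sketch of that write-after-read hazard and the barrier that closes it (illustrative names and sizes, not CK code; launch with 256 threads):

#include <hip/hip_runtime.h>

// One LDS region reused by two stages, the way the fused kernel reuses
// its A/B space for the B1 tile and the softmax workspace.
__global__ void lds_reuse_sketch(const float* in, float* out)
{
    __shared__ float tile[256];

    // stage 0: write, publish, read
    tile[threadIdx.x] = in[threadIdx.x];
    __syncthreads();                         // stage-0 writes visible
    float partial = tile[255 - threadIdx.x];

    // Without this barrier, fast threads overwrite tile[] for stage 1
    // while slow threads are still reading stage-0 data (WAR hazard).
    __syncthreads();                         // "wait for LDS read"

    // stage 1: same bytes, different contents
    tile[threadIdx.x] = partial * 2.0f;
    __syncthreads();
    out[threadIdx.x] = tile[threadIdx.x];
}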
@@ -701,9 +701,7 @@ struct BlockwiseGemmXdlops_v2
         const auto waveId_m = wave_idx[I0];
         const auto waveId_n = wave_idx[I1];
 
-        const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
-        const auto blk_idx =
-            TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
+        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
 
         constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
@@ -922,13 +920,13 @@ struct BlockwiseGemmXdlops_v2
     }
 
     protected:
-    // A[M0, M1, M2, KPerThread]
+    // A[M0, M1, M2, KPack]
     static constexpr auto a_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
 
-    // B[N0, N1, N2, KPerThread]
+    // B[N0, N1, N2, KPack]
     static constexpr auto b_thread_desc_ =
-        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
+        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
 
     // C[M, N, NumRegXdlops]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
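A hedged reading of the hunk above (the "shrink blockwise gemm v2 thread buffer size" bullet): v2 stages one KPack-wide slice of A and B per xdlops step rather than a thread's whole K run, so a KPerThread-deep staging buffer over-allocates registers. With illustrative numbers (the K0PerThread split below is hypothetical, not CK's):

constexpr int KPack       = 8;
constexpr int K0PerThread = 4;                   // hypothetical k-loop trips
constexpr int KPerThread  = K0PerThread * KPack; // 32 elements staged before

constexpr int old_elems = 1 * 1 * 1 * KPerThread; // [M0, M1, M2, KPerThread]
constexpr int new_elems = 1 * 1 * 1 * KPack;      // [M0, M1, M2, KPack]

static_assert(new_elems * K0PerThread == old_elems,
              "one KPack slice per step still covers the same K range");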
@@ -181,36 +181,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
-        // lds max alignment
-        constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
-
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b_block_space_size_aligned =
-            math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
-
-        // LDS allocation for C shuffle in LDS
-        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        constexpr auto c_block_size =
-            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
-
-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(FloatAB),
-                         c_block_size * sizeof(FloatCShuffle));
+        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
+                                         SharedMemTrait::b_block_space_size_aligned) *
+                                        sizeof(FloatAB);
+        const index_t gemm1_bytes_end =
+            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
+            sizeof(FloatAB);
+        const index_t c_block_bytes_end =
+            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+
+        return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end);
     }
 
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -312,6 +292,36 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
 
+    struct SharedMemTrait
+    {
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto a_block_desc_ak0_m_ak1 =
+            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
+        static constexpr auto b_block_desc_bk0_n_bk1 =
+            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+        static constexpr auto b1_block_desc_bk0_n_bk1 =
+            GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
+
+        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
+
+        static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
+            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+        static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
+            b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
+
+        static constexpr auto a_block_space_offset  = 0;
+        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
+        static constexpr auto b1_block_space_offset = 0;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
+            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
+        static constexpr auto c_block_space_size =
+            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
+    };
+
     template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
@@ -358,9 +368,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
 
-        // lds max alignment
-        constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
-
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -464,14 +471,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
 
         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
@@ -588,7 +593,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         // reuse LDS space for gemm0's b_block_buf
         auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr index_t Gemm1KPack = math::max(
@@ -611,10 +616,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             MXdlPerWave,
             Gemm1NXdlPerWave,
             Gemm1KPack,
-            false,
+            false,      // TransposeC
             Gemm1KPack, // AMmaKStride
             Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
-            make_tuple(0, 0, 0, 0)}; // TransposeC
+                        // BMmaKStride
+            make_tuple(0, 0, 0, 0)}; // A_origin
 
         auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
@@ -699,6 +705,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                                   a1_thread_desc_k0_m_k1,
                                   make_tuple(I0, I0, I0),
                                   a1_thread_buf);
+
             block_sync_lds();
 
             gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
@@ -182,11 +182,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        return math::max((SharedMemTrait::a_block_space_size_aligned +
-                          SharedMemTrait::b_block_space_size_aligned) *
-                             sizeof(FloatAB) +
-                         SharedMemTrait::reduction_workspace * sizeof(FloatGemmAcc),
-                         SharedMemTrait::c_block_size * sizeof(FloatCShuffle));
+        const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
+                                         SharedMemTrait::b_block_space_size_aligned) *
+                                        sizeof(FloatAB);
+        const index_t gemm1_bytes_end =
+            (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
+            sizeof(FloatAB);
+        const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
+                                           SharedMemTrait::reduction_space_size_aligned) *
+                                          sizeof(FloatGemmAcc);
+        const index_t c_block_bytes_end =
+            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+
+        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
     }
 
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -302,22 +310,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
-        static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple(
+        static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
         static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
 
-        // B1 can reuse B's LDS
-        static constexpr auto b_block_space_size_aligned =
-            math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value);
+        static constexpr auto a_block_space_offset  = 0;
+        static constexpr auto b_block_space_offset  = a_block_space_size_aligned.value;
+        static constexpr auto b1_block_space_offset = 0;
 
         // LDS allocation for reduction
-        static constexpr index_t reduction_workspace = BlockSize;
+        static constexpr index_t reduction_space_size_aligned =
+            math::integer_least_multiple(BlockSize, max_lds_align);
+        static constexpr auto reduction_space_offset = 0;
 
         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
             GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-        static constexpr auto c_block_size =
+        static constexpr auto c_block_space_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
     };
@@ -471,10 +482,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // LDS allocation for A and B: be careful of alignment
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
@@ -591,7 +603,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // reuse LDS space for gemm0's b_block_buf
         auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
+            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
         constexpr index_t Gemm1KPack = math::max(
@@ -617,7 +629,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             true,       // TransposeC
             Gemm1KPack, // AMmaKStride
             Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
-            make_tuple(0, 0, 0, 0)}; // TransposeC
+                        // BMmaKStride
+            make_tuple(0, 0, 0, 0)}; // A_origin
 
         auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
@@ -625,10 +638,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // Blockwise softmax
         //
         auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatGemmAcc*>(p_shared) +
-                SharedMemTrait::a_block_space_size_aligned * sizeof(FloatAB) / 4 +
-                SharedMemTrait::b_block_space_size_aligned * sizeof(FloatAB) / 4,
-            SharedMemTrait::reduction_workspace);
+            static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
+            SharedMemTrait::reduction_space_size_aligned);
 
         // get acc0 8D thread cluster
         constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
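The old workspace offset also hid a portability trap: multiplying by sizeof(FloatAB) and dividing by a literal 4 silently assumed a 4-byte FloatGemmAcc. Keeping offsets in elements of the pointer's own type, as the SharedMemTrait offsets now do, removes that assumption. A standalone illustration with hypothetical types:

#include <cstddef>

using FloatAB      = unsigned short; // stand-in for a 2-byte fp16 payload
using FloatGemmAcc = float;          // 4-byte accumulator

constexpr std::size_t ab_elems = 1024; // FloatAB elements ahead of the workspace

// byte math with a hard-coded 4: correct only while sizeof(FloatGemmAcc) == 4
constexpr std::size_t off_hack = ab_elems * sizeof(FloatAB) / 4;
// the same offset derived from the actual accumulator size
constexpr std::size_t off_safe = ab_elems * sizeof(FloatAB) / sizeof(FloatGemmAcc);

static_assert(off_hack == off_safe,
              "equal today, divergent for 2- or 8-byte accumulators");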
@@ -717,7 +728,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                               mathext::exp(max - running_max_new) * sum;
 
-            block_sync_lds();
             // gemm1
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
@@ -736,12 +746,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                      b1_block_slice_copy_step);
 
+                block_sync_lds(); // wait for reduction LDS read
                 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
 
                 // main body
                 if constexpr(num_gemm1_k_block_inner_loop > 1)
                 {
                     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                         a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                               make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
@@ -749,6 +760,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                               a1_thread_desc_k0_m_k1,
                                               make_tuple(I0, I0, I0),
                                               a1_thread_buf);
+
                         b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
 
                         block_sync_lds();
@@ -773,6 +785,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                           a1_thread_desc_k0_m_k1,
                                           make_tuple(I0, I0, I0),
                                           a1_thread_buf);
+
                     block_sync_lds();
 
                     gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
@@ -817,6 +830,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_max = running_max_new;
             running_sum = running_sum_new;
 
+            block_sync_lds(); // wait for gemm1 LDS read
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
 
         // shuffle C and write out
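Taken together, the moved and added barriers give the outer j loop a strict read-then-overwrite discipline on the shared bytes: sync after the softmax reduction has read its workspace and before B1 overwrites it, then sync after gemm1 has read the B1 tile and before the next trip's workspace write. A reduced HIP sketch of that schedule (illustrative only; launch with 256 threads, b1_src holding iters * 256 floats):

#include <hip/hip_runtime.h>

__global__ void outer_loop_sketch(const float* b1_src, float* dst, int iters)
{
    __shared__ float tile[256]; // reused by the reduction and the B1 tile

    float acc = 0.0f;
    tile[threadIdx.x] = 0.0f;   // seed the reduction workspace
    __syncthreads();

    for(int j = 0; j < iters; ++j)
    {
        float red = tile[255 - threadIdx.x]; // reduction reads workspace

        __syncthreads(); // wait for reduction LDS read before B1 overwrite

        tile[threadIdx.x] = b1_src[j * 256 + threadIdx.x]; // stage B1 tile
        __syncthreads(); // B1 writes visible
        acc += red + tile[255 - threadIdx.x]; // "gemm1" reads the B1 tile

        __syncthreads(); // wait for gemm1 LDS read before workspace reuse
        tile[threadIdx.x] = acc; // next trip's reduction workspace
        __syncthreads();
    }
    dst[threadIdx.x] = acc;
}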
@@ -819,7 +819,7 @@ struct XdlopsGemm
         index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
         index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
 
-        return CIndex{m_offset, n_offset};
+        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }
 
     static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
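With TransposeC folded into the return value here, callers such as BlockwiseGemmXdlops_v2 (first hunk above) no longer need a per-call-site swap. The shape of the change, as a self-contained analogue:

#include <utility>

template <bool TransposeC>
constexpr std::pair<int, int> c_index(int m_offset, int n_offset)
{
    // swap at the source instead of at every caller
    return TransposeC ? std::pair<int, int>{n_offset, m_offset}
                      : std::pair<int, int>{m_offset, n_offset};
}

static_assert(c_index<false>(2, 5) == std::pair<int, int>{2, 5});
static_assert(c_index<true>(2, 5) == std::pair<int, int>{5, 2});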