Commit 8dbb73b1 authored by aska-0096

format

parent cc6a534f
@@ -45,56 +45,55 @@ using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
-    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<
-        NumDimG,
+    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        ADataType,
        BDataType,
        DsDataType,
        EDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
        BElementOp,
        CDEElementOp,
        GemmSpec,
        ASpec,
        BSpec,
        DESpec,
        256,
        128,
        128,
        4,
        8,
        16,
        16,
        4,
        2,
        S<4, 64, 1>,
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<4, 64, 1>,
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        1,
        1,
        S<1, 32, 1, 8>,
        8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
@@ -327,7 +326,8 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
-DeviceMem e_device_buf(sizeof(EDataType) * e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
+DeviceMem e_device_buf(sizeof(EDataType) *
+                       e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
ck::index_t K = ck::accumulate_n<ck::index_t>(
    a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
-std::cout<<"GMNK="<<G<<", "<<M<<", "<<N<<", "<<K<<std::endl;
+std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                        sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;
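Editor's aside, not part of the commit: a minimal, self-contained sketch of how the flop and num_btype values computed above are usually turned into the TFLOPS and GB/s figures these examples report. The problem sizes and the kernel time below are hypothetical placeholders, and 2-byte (fp16) elements are assumed for all four tensors.

// Hedged sketch, not from the commit: converting the flop / byte counts above into throughput.
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t G = 4, M = 1024, N = 1024, K = 1024; // hypothetical problem size
    const double ave_time_ms = 1.5;                        // hypothetical kernel time in ms

    const std::size_t flop = std::size_t(2) * G * M * N * K;
    const std::size_t num_btype =
        2 * G * M * K + 2 * G * K * N + 2 * G * M * N + 2 * G * M * N; // A + B + D + E, 2 bytes each

    const double tflops     = static_cast<double>(flop) / 1.0e9 / ave_time_ms;      // TFLOP/s
    const double gb_per_sec = static_cast<double>(num_btype) / 1.0e6 / ave_time_ms; // GB/s

    std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
    std::cout << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
}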
...
@@ -5,7 +5,9 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
+endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -16,5 +18,7 @@ add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_soft
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
-add_custom_target(example_gemm_scale_softmax_gemm_wmma)
-add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+if(GPU_TARGETS MATCHES "gfx1100")
+add_custom_target(example_gemm_scale_softmax_gemm_wmma)
+add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
+endif()
@@ -94,17 +94,17 @@ using DeviceGemmInstance =
        TensorSpecB1,
        TensorSpecC,
        256,
        128, // MPerBlock
        128, // LPerBlock
        4,   // K0PerBlock
        8,   // K1
        64,  // NPerBlock
        4,   // L0PerBlock
        8,   // L1
        16,  // MPerWMMA
        16,  // LPerWMMA
        16,  // NPerWMMA
        // Per repeat = wave_m = wave_num, wave_n = 1
        1,   // MRepeat
        8,   // LRepeat
        4,   // NRepeat
...
@@ -218,7 +218,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
...
@@ -33,10 +33,10 @@ template <index_t BlockSize,
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
*   thread level: MRepeat x NRepeat x MAccVgprs
*   block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
*   A(if skip LDS): MRepeat x KPack
@@ -62,7 +62,8 @@ struct BlockwiseGemmWMMA
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
-static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
+static constexpr auto wmma_gemm =
+    WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
@@ -149,13 +150,8 @@ struct BlockwiseGemmWMMA
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
-return make_tuple(Number<m0>{},
-                  blk_idx[I0],
-                  waveId_m,
-                  Number<n0>{},
-                  waveId_n,
-                  blk_idx[I1],
-                  blk_idx[I2]);
+return make_tuple(
+    Number<m0>{}, blk_idx[I0], waveId_m, Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
}
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
@@ -169,7 +165,8 @@ struct BlockwiseGemmWMMA
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
              "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
-static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
+static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
+                  NPerBlock % (NPerWMMA * NRepeat) == 0,
              "wrong!");
}
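Editor's aside, not part of the commit: a small host-side check of the wave decomposition that the static_assert above enforces, plugging in the tile parameters labelled in the fp16 attention example earlier in this diff (MPerBlock = 128, MPerWMMA = 16, MRepeat = 1, LPerBlock = 128, LPerWMMA = 16, LRepeat = 8; the blockwise helper calls that second dimension N). The 256-thread block size and the 32-lane wave assumed for gfx1100 are my assumptions, not stated in the hunks.

#include <cstdio>

// Assumed values: the tile sizes come from the example's template arguments above,
// the 32-lane wave and 256-thread block are assumptions for gfx1100.
constexpr int WaveSize  = 32;
constexpr int BlockSize = 256;
constexpr int MPerBlock = 128, MPerWMMA = 16, MRepeat = 1;
constexpr int LPerBlock = 128, LPerWMMA = 16, LRepeat = 8;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / (1 * 16) = 8
constexpr int LWaves = LPerBlock / (LRepeat * LPerWMMA); // 128 / (8 * 16) = 1

// Mirrors "ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize" above.
static_assert(MWaves * LWaves * WaveSize == BlockSize, "tile does not fill the thread block");

int main() { std::printf("MWaves=%d LWaves=%d\n", MWaves, LWaves); }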
@@ -180,20 +177,15 @@ struct BlockwiseGemmWMMA
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
    wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
-// constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
-// constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+// constexpr auto NSubGroup =
+// c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; constexpr auto MThreadPerSubGroup
+// = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
    // |MRepeat |MWave |MSubGroup |NRepeat |NWave
    // |NThreadPerSubGroup |MAccVgprs
-   make_tuple(Number<MRepeat>{},
-              I1,
-              I1,
-              Number<NRepeat>{},
-              I1,
-              I1,
-              NAccVgprs));
+   make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
// Thread level, register decriptor. Vector-write
...
@@ -393,10 +393,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
// Gridwise descriptor, mapping to whole given provblem.
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {}));
@@ -604,10 +604,12 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_n_k_ =
    DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
-ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
-e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
+ds_grid_desc_m_n_ =
+    DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
+e_grid_desc_m_n_ =
+    DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
@@ -619,8 +621,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
        ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock =
-   GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-       e_grid_desc_m_n_);
+   GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
// for sanity check of vector memory access
a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
@@ -696,9 +697,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
-const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
+const index_t grid_size =
+    arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
-const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+const auto K =
+    arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
...
@@ -54,10 +54,10 @@ template <index_t NumDimG,
ck::index_t MPerBlock,
ck::index_t LPerBlock,
ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
ck::index_t K1, //
ck::index_t NPerBlock,
ck::index_t L0PerBlock,
ck::index_t L1,
ck::index_t MPerWMMA,
ck::index_t LPerWMMA,
ck::index_t NPerWMMA,
@@ -136,7 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
    Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
    Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
@@ -261,7 +261,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
K1, //
NPerBlock,
L0PerBlock,
L1,
MPerWMMA,
LPerWMMA,
NPerWMMA,
@@ -339,10 +339,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_c_grid_{p_c_grid},
a_grid_desc_ak0_m_ak1_{
    DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-b0_grid_desc_bk0_l_bk1_{
-    DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
-b1_grid_desc_bl0_n_bl1_{
-    DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
+b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
+    b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
+    b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
c_grid_desc_m_n_{
    Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
a_grid_desc_g_m_k_{
@@ -408,7 +408,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
@@ -450,9 +450,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
-const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+const index_t grid_size =
+    arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
-const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+const auto K =
+    arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
@@ -552,11 +554,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
{
@@ -592,8 +594,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
    B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
-const auto c_stride_lowest =
-    arg.c_mz_nz_strides_[1];
+const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
     c_stride_lowest == 1))
@@ -610,8 +611,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
-static auto
-MakeArgument(
+static auto MakeArgument(
    const ADataType* p_a,
    const B0DataType* p_b0,
    const B1DataType* p_b1,
@@ -634,7 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
                p_b0,
@@ -664,8 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
// polymorphic
-std::unique_ptr<BaseArgument>
-MakeArgumentPointer(
+std::unique_ptr<BaseArgument> MakeArgumentPointer(
    const void* p_a,
    const void* p_b0,
    const void* p_b1,
...
@@ -135,10 +135,10 @@ template <typename FloatA,
index_t MPerBlock,
index_t LPerBlock,
index_t K0PerBlock, // K0 * K1Value = Gemm0 GEMM_K Dim
index_t K1Value,
index_t NPerBlock,
index_t L0PerBlock,
index_t L1Value,
index_t MPerWmma,
index_t LPerWmma,
index_t NPerWmma,
@@ -209,8 +209,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
__host__ __device__ static constexpr auto
MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t A_K0 = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return transform_tensor_descriptor(
@@ -227,8 +227,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
__host__ __device__ static constexpr auto
MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
{
constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
return transform_tensor_descriptor(
    B0BlockDesc_BK0_L_BK1{},
@@ -250,18 +250,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return transform_tensor_descriptor(
    A1BlockDesc_AL0_M_AL1{},
    make_tuple(make_pass_through_transform(Number<A_L0>{}),
-              make_unmerge_transform(
-                  make_tuple(Number<MRepeat>{}, I1, I1)),
+              make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
               make_pass_through_transform(Number<A_L1>{})),
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B1BlockDesc_BL0_N_BL1>
-__host__ __device__ static constexpr auto MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
+__host__ __device__ static constexpr auto
+MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
{
constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
    B1BlockDesc_BL0_N_BL1{},
@@ -317,17 +317,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
-const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
-                                 SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
+const index_t gemm0_bytes_end =
+    (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
+     SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
const index_t gemm1_bytes_end =
    (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
    sizeof(FloatB1);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
                                   SharedMemTrait::reduction_space_size_aligned) *
                                  sizeof(FloatAcc0);
const index_t c_block_bytes_end =
    SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
@@ -360,8 +361,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
return false;
}
-if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 &&
-     N % NPerBlock == 0))
+if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
{
return false;
}
@@ -432,7 +432,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
    remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
@@ -453,7 +453,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    b1_block_desc_bl0_n_bl1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for reduction
@@ -466,10 +466,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
static constexpr auto c_block_space_size =
-   c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize();
+   c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
+       .GetElementSpaceSize();
};
-template <bool HasMainKBlockLoop, typename C0MatrixMask, typename Block2CTileMap = DefaultBlock2CTileMap>
+template <bool HasMainKBlockLoop,
+          typename C0MatrixMask,
+          typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
                           const FloatB0* __restrict__ p_b0_grid,
                           const FloatB1* __restrict__ p_b1_grid,
...
@@ -165,7 +165,7 @@ __global__ void
static constexpr index_t NumDTensor =
    DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
DsPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}(
@@ -530,7 +530,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
template <typename DsGridDesc_M_N_>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n)
{
return generate_tuple(
    [&](auto i) {
        return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
...
@@ -141,10 +141,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
template <typename ABlockDesc_AK0_M_AK1>
-__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
+__host__ __device__ static constexpr auto
+MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t A_K0 = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = ABlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return transform_tensor_descriptor(
@@ -157,11 +158,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename BBlockDesc_BK0_N_BK1>
-__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+__host__ __device__ static constexpr auto
+MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr index_t B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
    BBlockDesc_BK0_N_BK1{},
...
@@ -1311,11 +1311,11 @@ template <typename SrcData,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
uint32_t LowEightRowlaneIdx,
uint32_t HighEightRowLaneIdx,
bool IntraRowSwizzlePerm,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                   bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
@@ -1383,7 +1383,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
    // idx_md err. as dst access 2 strided elements while src visit 1 per loop
    constexpr index_t src_offset = src_desc.CalculateOffset(
        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1398,24 +1398,37 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
// apply intra-row swizzle permute
-if constexpr(IntraRowSwizzlePerm){
-    // origin: 0xfedcba98, 0x76543210
-    temp = __builtin_amdgcn_permlane16(temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
+if constexpr(IntraRowSwizzlePerm)
+{
+    // origin:
+    // 0xfedcba98,
+    // 0x76543210
+    temp = __builtin_amdgcn_permlane16(
+        temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
    v_this_row = type_convert<float>(temp);
}
// apply inter-row permute.
-temp = __builtin_amdgcn_permlanex16(temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
+temp = __builtin_amdgcn_permlanex16(temp,
+                                    type_convert<int>(v_this_row),
+                                    LowEightRowlaneIdx,
+                                    HighEightRowLaneIdx,
+                                    1,
+                                    0);
v_theother_row = type_convert<float>(temp);
-if(get_thread_local_1d_id() % 32 < 16){
+if(get_thread_local_1d_id() % 32 < 16)
+{
    // apply type convert
    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
-   dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_theother_row);
+   dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
+       type_convert<DstData>(v_theother_row);
}
-else{
+else
+{
    // apply type convert
-   dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v_this_row);
+   dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
+       type_convert<DstData>(v_this_row);
    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
}
});
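Editor's aside, not part of the commit: a host-side model of how I read the nibble masks passed to the permlane builtins above. As I understand the lane-select encoding (this is an assumption, check the RDNA3 ISA docs), each destination lane of a 16-lane row takes its value from the source lane named by one 4-bit field, lanes 0-7 from the low word and lanes 8-15 from the high word; permlanex16 applies the same selection against the other 16-lane row of the wave.

// Hedged host-side model (not the real builtins): what 0xeca86420 / 0xfdb97531 select.
#include <cstdint>
#include <cstdio>

// Return the source lane that destination lane `dst` (0..15) reads from.
int select_lane(uint32_t sel_lo, uint32_t sel_hi, int dst)
{
    const uint32_t word = dst < 8 ? sel_lo : sel_hi;
    return (word >> (4 * (dst % 8))) & 0xF;
}

int main()
{
    const uint32_t lo = 0xeca86420; // masks used by the intra-row swizzle above
    const uint32_t hi = 0xfdb97531;

    for(int dst = 0; dst < 16; ++dst)
        std::printf("dst lane %2d <- src lane %2d\n", dst, select_lane(lo, hi, dst));
    // Prints: lanes 0-7 gather the even source lanes 0,2,...,14 and lanes 8-15 gather the
    // odd source lanes 1,3,...,15, i.e. an even/odd de-interleave within each 16-lane row.
    return 0;
}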
...
@@ -444,7 +444,7 @@ struct WmmaGemm
make_pass_through_transform(MWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
make_tuple(Sequence<0>{},
...
@@ -24,10 +24,9 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-reg_c.template AsType<float8_t>()(Number<0>{}) =
-    __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
-    AsType<float8_t>()[Number<0>{}]);
+reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+    reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
...
-#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
-git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
+# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'