Commit e628b162 authored by Adam Osewski

Code cleanup.

parent 2a16c61c
@@ -128,21 +128,12 @@ __global__ void
     const auto StrideA = gemm_desc_ptr[group_id].StrideA;
     const auto StrideB = gemm_desc_ptr[group_id].StrideB;

-    // results_buffer.Clear();
     b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);

     // Iterate over K dimension for this [M,N] tile
     // still in the same GEMM && the same [M,N] tile
-    // TODO: change desc so that few K-tiles will be done in single GEMM.
-    // do
-    // {
     auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
-    // if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
-    // {
-    //     printf("bid: %d, k_tiles: %d\n",
-    //            static_cast<index_t>(blockIdx.x),
-    //            k_tiles);
-    // }
     // just accumulate results in registers!
     GridwiseGemm::template RunGEMM(p_a_grid,
                                    p_b_grid,
@@ -161,8 +152,6 @@ __global__ void
     // Move to the last processed k-tile
     b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
-    // } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
     // if (changed group_id || next [M,N] tile)
     // With cshuffle at store partials all workgroups have to store
     // their partials to workspace gmem.
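Note on the scheme the removed loop scaffolding and the comment above refer to: plain split-K, where each workgroup accumulates a contiguous range of K-tiles in registers and the partial results are reduced later (here via a workspace buffer in global memory). A minimal host-side sketch of the idea follows; the function, its names, and the scalar loops are illustrative assumptions, not CK code.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Scalar split-K GEMM sketch: K is split into k_batch chunks, each chunk produces a
    // partial C ("workspace"), and the partials are summed into the final output.
    void splitk_gemm_sketch(const std::vector<float>& A, // M x K, row-major
                            const std::vector<float>& B, // K x N, row-major
                            std::vector<float>& C,       // M x N, row-major
                            int M, int N, int K, int k_batch)
    {
        const int k_per_batch = (K + k_batch - 1) / k_batch;
        std::vector<float> partials(static_cast<std::size_t>(k_batch) * M * N, 0.f);

        for(int kb = 0; kb < k_batch; ++kb) // each chunk plays the role of one split-K workgroup
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    for(int k = kb * k_per_batch; k < std::min(K, (kb + 1) * k_per_batch); ++k)
                        partials[(static_cast<std::size_t>(kb) * M + m) * N + n] +=
                            A[m * K + k] * B[k * N + n];

        for(int m = 0; m < M; ++m) // reduction of the partial results
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int kb = 0; kb < k_batch; ++kb)
                    acc += partials[(static_cast<std::size_t>(kb) * M + m) * N + n];
                C[static_cast<std::size_t>(m) * N + n] = acc;
            }
    }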
@@ -220,11 +209,6 @@ __global__ void
         p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
     });

-    // if (threadIdx.x == 0)
-    // {
-    //     p_e_grid[blockIdx.x] = 0;
-    // }
     GridwiseGemm::template RunWrite(p_ds_grid,
                                     p_e_grid,
                                     acc_buff,
@@ -766,16 +750,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     }

     auto preprocess = [&]() {
-        // std::cout << "[preprocess] p_flags: " << p_flags
-        //           << ", flag count: " << flag_count
-        //           << ", bytes: " << flag_count * sizeof(uint32_t)
-        //           << ", stream id: " << stream_config.stream_id_
-        //           << std::endl;
         hip_check_error(hipMemsetAsync(
             p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
-        // TODO: For debug only!
-        hip_check_error(hipMemsetAsync(
-            dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
     };

     return launch_and_time_kernel_with_preprocess(
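After this change the preprocess hook only clears the synchronization flags on the kernel's stream; the debug memset of the accumulation workspace is gone. A minimal sketch of that clear-before-launch pattern, with hypothetical names and simplified error handling:

    #include <hip/hip_runtime.h>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Asynchronously zero a flag buffer on the given stream so that a kernel launched
    // later on the same stream observes cleared flags.
    inline void clear_flags(std::uint32_t* p_flags, std::size_t flag_count, hipStream_t stream)
    {
        const hipError_t err =
            hipMemsetAsync(p_flags, 0, flag_count * sizeof(std::uint32_t), stream);
        if(err != hipSuccess)
            std::fprintf(stderr, "hipMemsetAsync failed: %s\n", hipGetErrorString(err));
    }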
@@ -967,11 +943,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
         flag_count * sizeof(uint32_t);

-    std::cout << "[GetWorkspaceSize]: "
-              << "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
-              << ", tiles_per_block: " << tiles_per_block << ", flag_count: " << flag_count
-              << ", size_bytes: " << size_bytes << std::endl;
     return size_bytes;
 }
@@ -120,9 +120,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     return math::integer_least_multiple(N, NPerBlock);
 }

-__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
+__host__ __device__ static auto CalculateKPadded(index_t K)
 {
-    return math::integer_least_multiple(K, KPerBlock * K_Batch);
+    return math::integer_least_multiple(K, KPerBlock);
 }

 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
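With this change K is padded up to a multiple of KPerBlock only, rather than of KPerBlock * K_Batch. For reference, a standalone sketch of the round-up computation, assuming math::integer_least_multiple has the usual ceiling-to-a-multiple semantics for positive integers:

    // Round x up to the nearest multiple of m (assumed semantics of math::integer_least_multiple).
    constexpr int integer_least_multiple(int x, int m) { return ((x + m - 1) / m) * m; }

    // Example with K = 200, KPerBlock = 32, K_Batch = 4:
    static_assert(integer_least_multiple(200, 32 * 4) == 256, "old behaviour: pad K to a multiple of KPerBlock * K_Batch");
    static_assert(integer_least_multiple(200, 32) == 224, "new behaviour: pad K to a multiple of KPerBlock only");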
@@ -142,7 +142,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }

 __host__ __device__ static auto
-MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
+MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t StrideA)
 {
     const auto a_grid_desc_m_k = [&]() {
         if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -155,7 +155,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         }
     }();

-    const auto KPad = CalculateKPadded(K, KBatch);
+    const auto KPad = CalculateKPadded(K);

     const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
         a_grid_desc_m_k,
@@ -190,7 +190,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }

 __host__ __device__ static auto
-MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
+MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t StrideB)
 {
     const auto b_grid_desc_k_n = [&]() {
         if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -204,7 +204,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }();

     const auto NPad = CalculateNPadded(N);
-    const auto KPad = CalculateKPadded(K, KBatch);
+    const auto KPad = CalculateKPadded(K);

     const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
         b_grid_desc_k_n,
@@ -239,8 +239,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 }

 private:
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1))>;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1))>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1))>;
     using ABlockDesc_AK0PerB_MPerB_AK1 =
         remove_cvref_t<decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())>;
@@ -377,14 +377,41 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                       const index_t StrideE,
                                       const index_t KBatch)
 {
-    const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
-    const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
+    const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA);
+    const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB);
     const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

-    if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+    const auto IsMPadded = []() -> bool {
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
+            return true;
+        else
+            return false;
+    };
+
+    const auto IsNPadded = []() -> bool {
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
+            return true;
+        else
+            return false;
+    };
+
+    const auto IsKPadded = []() -> bool {
+        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
+            return true;
+        else
+            return false;
+    };
+
+    if constexpr(!IsMPadded())
     {
         if(!(M % MPerBlock == 0))
         {
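The IsMPadded / IsNPadded / IsKPadded helpers introduced above fold GemmSpec into compile-time booleans so the if constexpr guards below stay short. A standalone sketch of the same pattern with a simplified, locally defined GemmSpecialization enum (the enumerator list here is an assumption for illustration, not CK's exact definition):

    enum class GemmSpecialization
    {
        Default,
        MPadding,
        NPadding,
        KPadding,
        MNPadding,
        MKPadding,
        NKPadding,
        MNKPadding
    };

    // Mirrors the IsMPadded lambda: true for every specialization that pads the M dimension.
    template <GemmSpecialization Spec>
    constexpr bool is_m_padded()
    {
        return Spec == GemmSpecialization::MPadding || Spec == GemmSpecialization::MNPadding ||
               Spec == GemmSpecialization::MKPadding || Spec == GemmSpecialization::MNKPadding;
    }

    static_assert(is_m_padded<GemmSpecialization::MKPadding>(), "M is padded for MKPadding");
    static_assert(!is_m_padded<GemmSpecialization::NKPadding>(), "M is not padded for NKPadding");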
@@ -398,10 +425,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         }
     }

-    if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+    if constexpr(!IsNPadded())
     {
         if(!(N % NPerBlock == 0))
         {
@@ -416,17 +440,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         }
     }

-    if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                   GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+    if constexpr(!IsKPadded())
     {
         if(!(K % KPerBlock == 0))
         {
             if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
             {
-                std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K
-                          << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                std::cout << "Arg K value is not a multiple of ! KPerBlock: " << K << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                           << std::endl;
             }
             return false;
@@ -552,6 +573,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }

     // check gridwise gemm pipeline
+    // This does not take into account that each WGP can run multiple k-batch tiles.
+    // However, that information is only known dynamically at kernel run-time,
+    // so this condition may be too restrictive.
     const auto num_k_loop =
         (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
         (KPerBlock * KBatch);
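For illustration (not from the commit): since a_grid_desc_ak0_m_ak1.GetLength(I0) * GetLength(I2) recovers the padded K, a problem with K = 512, KPerBlock = 32 and KBatch = 4 gives num_k_loop = 512 / (32 * 4) = 4 main-loop iterations per k-batch tile. As the new comment notes, a workgroup that picks up several k-batch tiles at run-time actually executes more iterations than this static estimate, so the pipeline check based on num_k_loop may reject cases that would in fact run fine.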
@@ -562,8 +586,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     {
         std::cout << "The number of k loops (" << num_k_loop
                   << ") value is not supported by GridwiseGemm Pipeline."
-                  << " K0Padded: " << a_grid_desc_ak0_m_ak1.GetLength(I1) << __FILE__ << ":"
-                  << __LINE__ << ", in function: " << __func__ << std::endl;
+                  << " AK0Padded: " << a_grid_desc_ak0_m_ak1.GetLength(I0) << __FILE__
+                  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
     }
     return false;
 }
@@ -870,8 +894,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);

     // tensor descriptors for block/thread-wise copy
-    const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, stride_a, k_batch);
-    const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, stride_b, k_batch);
+    const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, stride_a);
+    const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, stride_b);

     RunGEMM(p_a_grid,
             p_b_grid,
@@ -1242,33 +1266,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0m1_n0n1n2, partial_acc_load_step);
     }

-    // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
-    // {
-    //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 0, 0-3]: [%f, %f,"
-    //            "%f, %f]\n",
-    //            static_cast<index_t>(blockIdx.x),
-    //            static_cast<index_t>(threadIdx.x),
-    //            static_cast<float>(acc_buff[Number<0>{}]),
-    //            static_cast<float>(acc_buff[Number<1>{}]),
-    //            static_cast<float>(acc_buff[Number<2>{}]),
-    //            static_cast<float>(acc_buff[Number<3>{}]));
-    //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 1, 0-3]: [%f, %f,"
-    //            "%f, %f]\n",
-    //            static_cast<index_t>(blockIdx.x),
-    //            static_cast<index_t>(threadIdx.x),
-    //            static_cast<float>(acc_buff[Number<8>{}]),
-    //            static_cast<float>(acc_buff[Number<9>{}]),
-    //            static_cast<float>(acc_buff[Number<10>{}]),
-    //            static_cast<float>(acc_buff[Number<11>{}]));
-    // }
 }

 template <typename Block2ETileMap, typename AccumulationBuffer>
 __device__ static void RunWrite(DsGridPointer p_ds_grid,
                                 EDataType* __restrict__ p_e_grid,
-                                /* void* __restrict__ p_shared, */
                                 AccumulationBuffer& acc_buff,
                                 const index_t M,
                                 const index_t N,
@@ -1323,6 +1325,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
     constexpr auto reduce_cluster_desc = make_cluster_descriptor(cluster_length_reduce);

+    // TODO similar assertion
+    // static_assert(
+    //     is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
+    //     "wrong! threads should be mapped to cover entire slicing window");
     static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
                   "Error! ThisThreadBlock::GetNumOfThread() too small");
@@ -1433,24 +1441,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         unpack2(cde_element_op, tie(acc_buff(acc_buf_offset + I)), src_data_refs);
     });

-    // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
-    // {
-    //     printf("[bid: %d, tid: %d, m_iter: %d, n_iter: %d], {RunWrite} AuxBuf v[0-3]: "
-    //            " [%f, %f, %f, %f]\n",
-    //            static_cast<index_t>(blockIdx.x),
-    //            static_cast<index_t>(threadIdx.x),
-    //            m_idx.value,
-    //            n_idx.value,
-    //            static_cast<float>(aux_vgpr_buf[Number<0>{}]),
-    //            static_cast<float>(aux_vgpr_buf[Number<1>{}]),
-    //            static_cast<float>(aux_vgpr_buf[Number<2>{}]),
-    //            static_cast<float>(aux_vgpr_buf[Number<3>{}]));
-    // }
     e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
                      make_tuple(I0, m_idx, I0, n_idx, I0),
-                     // aux_vgpr_buf,
                      acc_buff,
                      e_grid_desc_m0m1_n0n1n2,
                      e_grid_buf);
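For context, the unpack2 call in this hunk applies the CDE element-wise operation to each accumulator element (combining it with the corresponding D values) before e_grid_store writes the result to E. A rough sketch of the functor shape such an op typically has; this is an illustrative stand-in, not CK's actual definition:

    #include <hip/hip_runtime.h>

    // Illustrative CDE element-wise op: consumes the GEMM accumulator value c and one
    // auxiliary tensor value d, and produces the output element e.
    struct AddAndScale
    {
        float scale_ = 1.0f;

        template <typename E, typename C, typename D>
        __host__ __device__ void operator()(E& e, const C& c, const D& d) const
        {
            e = static_cast<E>(scale_ * (static_cast<float>(c) + static_cast<float>(d)));
        }
    };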
@@ -31,47 +31,44 @@ void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregu
                                                PassThrough,
                                                PassThrough>>>& instances);

-// void
-// add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                    Row,
-//                                                    Empty_Tuple,
-//                                                    Row,
-//                                                    F16,
-//                                                    F16,
-//                                                    Empty_Tuple,
-//                                                    F16,
-//                                                    PassThrough,
-//                                                    PassThrough,
-//                                                    PassThrough>>>& instances);
+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

-// void
-// add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                    Row,
-//                                                    Empty_Tuple,
-//                                                    Row,
-//                                                    F16,
-//                                                    F16,
-//                                                    Empty_Tuple,
-//                                                    F16,
-//                                                    PassThrough,
-//                                                    PassThrough,
-//                                                    PassThrough>>>& instances);
+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

-// void
-// add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                    Row,
-//                                                    Empty_Tuple,
-//                                                    Row,
-//                                                    F16,
-//                                                    F16,
-//                                                    Empty_Tuple,
-//                                                    F16,
-//                                                    PassThrough,
-//                                                    PassThrough,
-//                                                    PassThrough>>>& instances);
+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);

 #endif

 template <typename ALayout,
@@ -119,12 +116,12 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<ELayout, Row>)
        {
 #if defined(CK_ENABLE_FP16)
-           // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
-           //     op_ptrs);
-           // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
-           //     op_ptrs);
-           // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
-           //     op_ptrs);
+           add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1(
+               op_ptrs);
+           add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v1_interwave(
+               op_ptrs);
+           add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances_pipeline_v2(
+               op_ptrs);
 #endif
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
 # ONLY XDL_KERNELS
 add_instance_library(device_grouped_gemm_multiple_d_instance
     device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1.cpp
-    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1_interwave.cpp
-    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v2.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v1_interwave.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance_pipeline_v2.cpp
 )