Commit 15d96340 authored by Jakub Piasecki
Browse files

add reviewers' suggestions

parent d976670e
......@@ -18,7 +18,6 @@
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
......
......@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
......@@ -237,12 +236,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise =
GridwiseElementwise<CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2ETileMapKSplit,
Block2TileMap,
CDEElementwiseOperation,
BlockSize,
MPerBlock,
......@@ -737,7 +737,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
ck::Tuple<EGridDesc_M_N>,
CDDataTypes,
ck::Tuple<EDataType*>,
Block2ETileMapKSplit,
Block2TileMap,
CDEElementwiseOperation>;
return LaunchKernel(gemm_kernel,
elementwise_kernel,
......@@ -791,8 +791,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
arg.ds_grid_pointer_[i]),
type_convert<EDataType*>(arg.e_ptrs_[i]),
Block2ETileMapKSplit{
arg.elementwise_c_grid_descs_m_n_[i], B2E_M01, arg.K_BATCH},
Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
arg.cde_element_op_);
}
return time;
......
......@@ -206,6 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
......@@ -238,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
......@@ -256,6 +259,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
......@@ -266,6 +271,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -99,7 +99,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
n_iter = std::stoi(argv[17]);
}
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
......@@ -150,7 +149,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
#endif
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment