Commit 15d96340 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add reviewers' suggestions

parent d976670e
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp> #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
...@@ -237,12 +236,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -237,12 +236,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
using Block2ETileMapKSplit = using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>; BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise = using GridwiseElementwise =
GridwiseElementwise<CDGridDesc_M_N, GridwiseElementwise<CDGridDesc_M_N,
ck::Tuple<EGridDesc_M_N>, ck::Tuple<EGridDesc_M_N>,
CDDataTypes, CDDataTypes,
ck::Tuple<EDataType*>, ck::Tuple<EDataType*>,
Block2ETileMapKSplit, Block2TileMap,
CDEElementwiseOperation, CDEElementwiseOperation,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -737,7 +737,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -737,7 +737,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
ck::Tuple<EGridDesc_M_N>, ck::Tuple<EGridDesc_M_N>,
CDDataTypes, CDDataTypes,
ck::Tuple<EDataType*>, ck::Tuple<EDataType*>,
Block2ETileMapKSplit, Block2TileMap,
CDEElementwiseOperation>; CDEElementwiseOperation>;
return LaunchKernel(gemm_kernel, return LaunchKernel(gemm_kernel,
elementwise_kernel, elementwise_kernel,
...@@ -791,8 +791,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -791,8 +791,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid),
arg.ds_grid_pointer_[i]), arg.ds_grid_pointer_[i]),
type_convert<EDataType*>(arg.e_ptrs_[i]), type_convert<EDataType*>(arg.e_ptrs_[i]),
Block2ETileMapKSplit{ Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0),
arg.elementwise_c_grid_descs_m_n_[i], B2E_M01, arg.K_BATCH}, arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)},
arg.cde_element_op_); arg.cde_element_op_);
} }
return time; return time;
......
...@@ -206,6 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -206,6 +206,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
...@@ -238,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -238,6 +239,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
} }
} }
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> && else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
...@@ -256,6 +259,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -256,6 +259,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
} }
} }
#endif
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> && else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>) is_same_v<EDataType, bhalf_t>)
{ {
...@@ -266,6 +271,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -266,6 +271,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs); op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -99,7 +99,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[]) ...@@ -99,7 +99,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
n_iter = std::stoi(argv[17]); n_iter = std::stoi(argv[17]);
} }
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t, ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
...@@ -150,7 +149,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[]) ...@@ -150,7 +149,6 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
} }
#endif
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment