Commit 959ddcf8 authored by Jing Zhang's avatar Jing Zhang
Browse files

clang-format

parent db2e4cf4
...@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
...@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
...@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded; karg.KPadded = k_padded;
karg.K0Padded = k0_padded; karg.K0Padded = k0_padded;
karg.k_batch = K_BATCH; karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
......
...@@ -133,8 +133,6 @@ void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances( ...@@ -133,8 +133,6 @@ void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -199,13 +197,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -199,13 +197,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> && else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
} }
} }
return op_ptrs; return op_ptrs;
......
...@@ -27,7 +27,7 @@ enum struct GemmDataType ...@@ -27,7 +27,7 @@ enum struct GemmDataType
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_BF16_BF16, // 2 BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
F16_F8_F16, // 4 F16_F8_F16, // 4
}; };
#define OP_NAME "grouped_gemm" #define OP_NAME "grouped_gemm"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment