Commit c8f6d5d1 authored by Chao Liu
Browse files

clean

parent 7b4de775
...@@ -31,7 +31,7 @@ __global__ void ...@@ -31,7 +31,7 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdlops_v2r3( kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_, const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -44,31 +44,31 @@ __global__ void ...@@ -44,31 +44,31 @@ __global__ void
#if 1 #if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) { static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ && if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count) i < group_count)
{ {
auto group_id = i; auto group_id = i;
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_[group_id].a_ptr, gemm_descs[group_id].a_ptr,
gemm_desc_[group_id].b_ptr, gemm_descs[group_id].b_ptr,
gemm_desc_[group_id].c_ptr, gemm_descs[group_id].c_ptr,
p_shared, p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_, gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_, gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_); gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
} }
}); });
#else #else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_); const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
index_t group_id = 0; index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) { static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd && group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
i < group_count) i < group_count)
? i ? i
: group_id; : group_id;
...@@ -91,7 +91,7 @@ __global__ void ...@@ -91,7 +91,7 @@ __global__ void
block_id_grp); block_id_grp);
#endif #endif
#else #else
ignore = gemm_desc_; ignore = gemm_descs;
ignore = group_count; ignore = group_count;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
......
...@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32> ...@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x16_t>()(Number<0>{}) = reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a), __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
bit_cast<int>(reg_b), bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}], reg_c.template AsType<int32x16_t>()[Number<0>{}],
0, 0,
0, 0,
...@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> ...@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a), __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
bit_cast<int>(reg_b), bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
0, 0,
0, 0,
......
...@@ -169,8 +169,8 @@ check_err(const std::vector<T>& out, ...@@ -169,8 +169,8 @@ check_err(const std::vector<T>& out,
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const int64_t out_v = static_cast<int64_t>(out[i]); const auto out_v = static_cast<int64_t>(out[i]);
const int64_t ref_v = static_cast<int64_t>(ref[i]); const auto ref_v = static_cast<int64_t>(ref[i]);
if(out_v != ref_v) if(out_v != ref_v)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment