Commit c8f6d5d1 authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 7b4de775
......@@ -31,7 +31,7 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -44,31 +44,31 @@ __global__ void
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
gemm_descs[group_id].a_ptr,
gemm_descs[group_id].b_ptr,
gemm_descs[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
i < group_count)
? i
: group_id;
......@@ -91,7 +91,7 @@ __global__ void
block_id_grp);
#endif
#else
ignore = gemm_desc_;
ignore = gemm_descs;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
......
......@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b),
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x16_t>()[Number<0>{}],
0,
0,
......@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
bit_cast<int>(reg_b),
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
bit_cast<int32_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
......
......@@ -169,8 +169,8 @@ check_err(const std::vector<T>& out,
for(std::size_t i = 0; i < ref.size(); ++i)
{
const int64_t out_v = static_cast<int64_t>(out[i]);
const int64_t ref_v = static_cast<int64_t>(ref[i]);
const auto out_v = static_cast<int64_t>(out[i]);
const auto ref_v = static_cast<int64_t>(ref[i]);
if(out_v != ref_v)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment