Commit 02bf2be0 authored by Jing Zhang
Browse files

clean code

parent dfbe7e20
......@@ -13,6 +13,9 @@ namespace ck {
// GemmK = C * Y * X
template <index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmKPerWave,
typename... Wei,
typename... In,
typename... Out,
......@@ -106,9 +109,17 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
constexpr auto xdlops_gemm = XdlopsGemm<float, GemmMPerWave, GemmNPerWave, GemmKPerWave>{};
constexpr auto CLayout = xdlops_gemm.GetOutputLayout();
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)),
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
......
......@@ -26,7 +26,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto XdlopsGemm = XdlopsGemm_t<float, MPerWave, NPerWave, KPerWave>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPerWave>{};
static constexpr index_t WaveSize = 64;
......@@ -35,16 +35,16 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t MWaves = MPerBlock / MPerWave;
static constexpr index_t NWaves = NPerBlock / NPerWave;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
__device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
__device__ constexpr auto GetNumBlks() const
{
return XdlopsGemm.GetOutputLayout().GetNumBlks();
return xdlops_gemm.GetOutputLayout().GetNumBlks();
}
__device__ constexpr auto GetBlkSize() const
{
return XdlopsGemm.GetOutputLayout().GetBlkSize();
return xdlops_gemm.GetOutputLayout().GetBlkSize();
}
__device__ static auto CalculateAThreadOriginDataIndex()
......@@ -75,7 +75,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(blk_i);
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);
const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
......@@ -127,7 +127,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
make_tuple(I0, I0),
b_thread_buf);
XdlopsGemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
});
}
......
......@@ -333,7 +333,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
vector_type<float, 64> c_thread_buf;
constexpr auto c_vec_size = MPerBlock * NPerBlock / BlockSize;
vector_type<float, c_vec_size> c_thread_buf;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
......@@ -466,15 +468,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
{
constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
constexpr index_t K0 = OutputLayout.M1();
constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0();
constexpr index_t M0 = OutputLayout.M1();
constexpr index_t M1 = OutputLayout.N1();
constexpr index_t M2 = OutputLayout.M0();
// static_assert(K0 == 4 && K1 == 2 && K2 == 4, "");
// static_assert(M0 == 4 && M1 == 2 && M2 == 4, "");
constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<K0>{}, Number<1>{}, Number<K2>{}, Number<1>{}));
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks();
......@@ -508,16 +510,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_global_desc),
Sequence<K0, 1, K2, 1>,
Sequence<M0, 1, M2, 1>,
Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
3, // CThreadTransferSrcDstVectorDim,
1, // CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m1_m2_n_global_desc,
make_multi_index(k_thread_data_on_global / (K2 * K1),
k_thread_data_on_global % (K2 * K1) / K2,
k_thread_data_on_global % K2,
make_multi_index(k_thread_data_on_global / (M2 * M1),
k_thread_data_on_global % (M2 * M1) / M2,
k_thread_data_on_global % M2,
b_thread_data_on_global)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0),
......
......@@ -53,7 +53,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::run(a, b, reg_c);
return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
......@@ -74,19 +74,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
const auto p_a = reinterpret_cast<const float*>(a);
const auto p_b = reinterpret_cast<const float*>(b);
return intrin_mfma_f32_32x32x2f32(p_a, p_b, reg_c);
return intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
......@@ -548,7 +539,7 @@ struct xdlops_info
};
template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave>
struct XdlopsGemm_t
struct XdlopsGemm
{
struct MatrixIndex
{
......@@ -561,7 +552,7 @@ struct XdlopsGemm_t
return (MPerXdlops * NPerXdlops) / (mfma_type.m * mfma_type.n);
}
__device__ constexpr XdlopsGemm_t()
__host__ __device__ constexpr XdlopsGemm()
{
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
......@@ -849,10 +840,10 @@ struct XdlopsGemm_t
struct OutputLayout
{
__device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__device__ static constexpr index_t M0() { return mfma_type.group_size; }
__device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
......@@ -867,7 +858,7 @@ struct XdlopsGemm_t
}
};
__device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
__host__ __device__ static constexpr auto GetOutputLayout() { return OutputLayout{}; }
};
} // namespace ck
......
......@@ -241,7 +241,7 @@ template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
__device__ static void
run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
Run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
{
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
......@@ -272,12 +272,19 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
//}
//};
__device__ c_vec16_1_t::VecType
intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ static void
Run(const float& reg_a, const float& reg_b, vector_type<float, 16>& reg_c)
{
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
}
};
__device__ c_vec4_1_t::VecType
intrin_mfma_f32_16x16x4f32(const float* reg_a, const float* reg_b, c_vec4_1_t::VecType reg_c)
......
......@@ -71,13 +71,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(OutDesc::GetLengths()));
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif
// b thread copy 4x1
#if 0
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
......@@ -101,13 +101,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 32;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmKPerWave = 2;
constexpr index_t GemmM1 = GemmMPerWave;
constexpr index_t GemmN1 = GemmNPerWave;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock,
GemmNPerBlock>(
GemmNPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPerWave>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment