Commit d3341a67 authored by Jing Zhang
Browse files

xdlops refactor

parent b62bf8c3
...@@ -709,19 +709,59 @@ struct XdlopsGemm ...@@ -709,19 +709,59 @@ struct XdlopsGemm
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
} }
// Re-expresses a 6-dimensional C grid descriptor (M0, N0, M1, N1, M2, N2) as an
// 8-dimensional one by unmerging the M2 dimension into three dimensions.
//
// Input dimensions 0-3 are passed through unchanged, input dimension 4 (M2) is
// unmerged into output dimensions 4, 5, 6 with lengths
// (num_groups_blk, num_input_blks, group_size), and input dimension 5 (N2) is
// passed through as output dimension 7.
//
// Compile-time preconditions (checked with static_assert):
//   - N2 == mfma_type.num_threads_blk
//   - M2 == num_groups_blk * num_output_blks * group_size
//
// NOTE(review): the M2 check multiplies by num_output_blks, whereas the unmerge
// below splits M2 using num_input_blks. Confirm these two fields always agree
// for the supported mfma_type configurations.
template <typename CM0N0M1N1M2N2GridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc)
{
    // Lengths of the six input dimensions, queried in order M0, N0, M1, N1, M2, N2.
    constexpr auto len_m0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<0>{});
    constexpr auto len_n0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<1>{});
    constexpr auto len_m1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<2>{});
    constexpr auto len_n1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<3>{});
    constexpr auto len_m2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<4>{});
    constexpr auto len_n2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(Number<5>{});

    static_assert(len_n2 == mfma_type.num_threads_blk, "");

    static_assert(
        len_m2 == (mfma_type.num_groups_blk * mfma_type.num_output_blks * mfma_type.group_size),
        "");

    // Each transform consumes exactly one input dimension, in order.
    constexpr auto src_dim_ids = make_tuple(Sequence<0>{},
                                            Sequence<1>{},
                                            Sequence<2>{},
                                            Sequence<3>{},
                                            Sequence<4>{},
                                            Sequence<5>{});

    // Output placement: only the fifth transform (the unmerge) produces more than
    // one dimension, which shifts the last pass-through from index 5 to index 7.
    constexpr auto dst_dim_ids = make_tuple(Sequence<0>{},
                                            Sequence<1>{},
                                            Sequence<2>{},
                                            Sequence<3>{},
                                            Sequence<4, 5, 6>{},
                                            Sequence<7>{});

    return transform_dynamic_tensor_descriptor(
        c_m0_n0_m1_n1_m2_n2_grid_desc,
        make_tuple(make_pass_through_transform(len_m0),
                   make_pass_through_transform(len_n0),
                   make_pass_through_transform(len_m1),
                   make_pass_through_transform(len_n1),
                   make_unmerge_transform(make_tuple(mfma_type.num_groups_blk,
                                                     mfma_type.num_input_blks,
                                                     mfma_type.group_size)),
                   make_pass_through_transform(mfma_type.num_threads_blk)),
        src_dim_ids,
        dst_dim_ids);
}
// Returns MPerXdlops * NPerXdlops / mfma_type.wave_size (integer division).
// NOTE(review): presumably the per-lane accumulator register count for one
// xdlops tile - confirm against callers. Each line below appears twice because
// this is a side-by-side diff rendering of an unchanged context line.
__device__ static constexpr index_t GetRegSizePerXdlops() __device__ static constexpr index_t GetRegSizePerXdlops()
{ {
return MPerXdlops * NPerXdlops / mfma_type.wave_size; return MPerXdlops * NPerXdlops / mfma_type.wave_size;
} }
template <class ADesc, template <index_t c_offset, class FloatA, class FloatB, class FloatC>
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value || static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
...@@ -730,24 +770,35 @@ struct XdlopsGemm ...@@ -730,24 +770,35 @@ struct XdlopsGemm
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); static_for<0, KPack / mfma_type.k_base, 1>{}([&](auto k) {
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>( mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset / mfma_type.k_base>{}], p_a_wave[k], p_b_wave[k], p_c_thread);
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
}); });
} }
static constexpr auto GetBlkIdx()
{
const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_type.num_input_blks, mfma_type.num_threads_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx = threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto blk_id = blk_idx[Number<1>{}];
const auto blk_td = blk_idx[Number<2>{}];
return make_tuple(blk_id, blk_td);
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{ {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; const auto blk_idx = GetBlkIdx();
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk; const auto blk_id = blk_idx[Number<0>{}];
const auto blk_td = blk_idx[Number<1>{}];
index_t n_offset = blk_i * mfma_type.n + blk_td; index_t n_offset = blk_i * mfma_type.n + blk_td;
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
...@@ -755,24 +806,12 @@ struct XdlopsGemm ...@@ -755,24 +806,12 @@ struct XdlopsGemm
return CIndex{m_offset, n_offset}; return CIndex{m_offset, n_offset};
} }
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
// Returns the quotient of lane_id divided by mfma_type.num_threads_blk.
static constexpr auto GetBlkId(const index_t lane_id)
{
    const auto threads_per_blk = mfma_type.num_threads_blk;
    return lane_id / threads_per_blk;
}
// Returns the remainder of lane_id divided by mfma_type.num_threads_blk.
static constexpr auto GetBlkTd(const index_t lane_id)
{
    const auto threads_per_blk = mfma_type.num_threads_blk;
    return lane_id % threads_per_blk;
}
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
...@@ -794,7 +833,7 @@ struct XdlopsGemm ...@@ -794,7 +833,7 @@ struct XdlopsGemm
} }
}; };
// Returns a value-initialized CLayout object. The line shows both sides of the
// diff: the accessor is named GetCLayout on the old side and GetCXdlopsLayout
// on the new side; the body is the same in both.
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } __host__ __device__ static constexpr auto GetCXdlopsLayout() { return CLayout{}; }
}; };
} // namespace ck } // namespace ck
......
...@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
...@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
FloatC, FloatC,
remove_reference_t<AK0MK1GridDesc>, remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>, remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>, remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
...@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel( float ave_time = launch_and_time_kernel(
...@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif #endif
return ave_time; return ave_time;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#define USE_CONV_FWD_V4R4R2_NHWC 1 #define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum ConvForwardAlgo enum ConvForwardAlgo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment