Commit b4d598bd authored by Chao Liu

refactor

parent 06ba0a90
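
This commit renames `CBlockClusterAdaptor` / `MakeCBlockClusterAdaptor` / `CalculateCM0M1N0N1ThreadOriginIndex` to the more descriptive `CBlockIdToM0N0BlockClusterAdaptor` / `MakeCBlockIdToM0N0BlockClusterAdaptor` / `CalculateCM0M1N0N1ThreadOriginOnBlock`, and the host driver now constructs `a_k_m0_m1_grid_desc` and `b_k_n0_n1_grid_desc` once and passes them down as kernel arguments (by value, as the kernel signature shows) rather than deriving them on the device. The hunks below are shown as the post-commit code. A minimal sketch of that host-to-kernel descriptor pattern, with purely illustrative names and types (not the library's real API):

#include <hip/hip_runtime.h>

// A descriptor kept small and trivially copyable so it can travel to the
// device inside the kernel-argument buffer, the way the grid descriptors
// in this commit do.
struct TileDesc
{
    int num_tiles;   // e.g. M0 = M / MPerBlock
    int tile_length; // e.g. M1 = MPerBlock
};

__global__ void gemm_like_kernel(const float* p_a_grid, TileDesc a_desc)
{
    // a_desc arrives by value; the device never rebuilds it.
    (void)p_a_grid;
    (void)a_desc;
}

void host_driver(const float* p_a_grid, int M, int MPerBlock)
{
    const TileDesc a_desc{M / MPerBlock, MPerBlock}; // built once on the host

    hipLaunchKernelGGL(
        gemm_like_kernel, dim3(a_desc.num_tiles), dim3(256), 0, 0, p_a_grid, a_desc);
}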
@@ -118,19 +118,27 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
    if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error(
            "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
    }

    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
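
The remaining hunks in this file pick one of four kernel instantiations from the two runtime flags `has_main_k_block_loop` and `has_double_tail_k_block_loop`, lifting them into compile-time template parameters. A self-contained sketch of that runtime-to-compile-time dispatch idiom (illustrative names, not the repository's API):

#include <cstdio>

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void run_variant()
{
    // In the real kernel these flags select the k-loop structure at compile
    // time, so each instantiation carries no dead branches.
    if constexpr(HasMainKBlockLoop)
        std::printf("main k-block loop\n");
    if constexpr(HasDoubleTailKBlockLoop)
        std::printf("double-tail epilogue\n");
}

void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        run_variant<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        run_variant<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        run_variant<false, true>();
    else
        run_variant<false, false>();
}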
@@ -142,13 +150,16 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKMGridDesc>,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

@@ -163,18 +174,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            p_c_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            a_k_m0_m1_grid_desc,
            b_k_n0_n1_grid_desc,
            c_m0_m10_m11_n0_n10_n11_grid_desc,
            c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKMGridDesc>,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

@@ -189,18 +205,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            p_c_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            a_k_m0_m1_grid_desc,
            b_k_n0_n1_grid_desc,
            c_m0_m10_m11_n0_n10_n11_grid_desc,
            c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKMGridDesc>,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

@@ -215,18 +236,23 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            p_c_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            a_k_m0_m1_grid_desc,
            b_k_n0_n1_grid_desc,
            c_m0_m10_m11_n0_n10_n11_grid_desc,
            c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKMGridDesc>,
                                     remove_reference_t<BKNGridDesc>,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

@@ -241,8 +267,10 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
            p_c_grid,
            a_k_m_grid_desc,
            b_k_n_grid_desc,
            a_k_m0_m1_grid_desc,
            b_k_n0_n1_grid_desc,
            c_m0_m10_m11_n0_n10_n11_grid_desc,
            c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
...
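
`ave_time` is the averaged kernel execution time returned by the (elided) launch-and-time helper. As an assumption about how such an average is typically measured in HIP, an event-based timing loop looks like this (illustrative helper, not the repository's own):

#include <hip/hip_runtime.h>

template <typename F>
float time_kernel_ms(F launch, int nrepeat)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
    {
        launch(); // enqueue one kernel launch
    }
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat; // average time per launch
}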
@@ -140,7 +140,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    public:
    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
@@ -161,14 +162,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
        static_assert(M0 == 2 && N0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
    {
        // lower: [M0, M1, N0, N1]
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
        // upper: [Tid, M0, M11, N0, N11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                       make_pass_through_transform(M0),
...
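
`CalculateCM0M1N0N1ThreadOriginOnBlock` composes two tensor adaptors to turn a flat thread id into that thread's data origin inside the block tile. The arithmetic those adaptors encode is, in essence, a divide/modulo decomposition; a simplified stand-alone sketch with made-up cluster sizes (not the adaptor machinery itself):

#include <array>

// thread_id is decomposed over an [m_thread_cluster, n_thread_cluster] grid
// of threads; each thread owns an m_per_thread x n_per_thread sub-tile.
std::array<int, 2> thread_origin_on_block(int thread_id,
                                          int n_thread_cluster,
                                          int m_per_thread,
                                          int n_per_thread)
{
    const int m_cluster_idx = thread_id / n_thread_cluster; // row in the thread grid
    const int n_cluster_idx = thread_id % n_thread_cluster; // column in the thread grid

    return {m_cluster_idx * m_per_thread, n_cluster_idx * n_per_thread};
}

// Example: 64 threads in an 8 x 8 cluster with 4 x 4 sub-tiles:
// thread 10 -> cluster (1, 2) -> data origin (4, 8) in the block tile.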
@@ -17,21 +17,26 @@ template <typename GridwiseGemm,
          typename FloatC,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename AKM0M1GridDesc,
          typename BKN0N1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_gemm_v1r2(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AKMGridDesc a_k_m_grid_desc,
            const BKNGridDesc b_k_n_grid_desc,
            const AKM0M1GridDesc a_k_m0_m1_grid_desc,
            const BKN0N1GridDesc b_k_n0_n1_grid_desc,
            const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
            const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -44,8 +49,10 @@ __global__ void
        p_shared_block,
        a_k_m_grid_desc,
        b_k_n_grid_desc,
        a_k_m0_m1_grid_desc,
        b_k_n0_n1_grid_desc,
        c_m0_m10_m11_n0_n10_n11_grid_desc,
        c_blockid_to_m0_n0_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>{},
        integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
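
The kernel sizes its LDS tile by querying the gridwise policy for a byte count and converting it to elements of `FloatAB`. A minimal sketch of that pattern with hard-coded stand-in numbers (the real code uses `GridwiseGemm::GetSharedMemoryNumberOfByte()`):

#include <hip/hip_runtime.h>

__global__ void kernel_with_lds_tile()
{
    // Stand-in for GridwiseGemm::GetSharedMemoryNumberOfByte():
    constexpr int shared_block_bytes = 16384;
    constexpr int shared_block_size  = shared_block_bytes / sizeof(float);

    // Statically sized LDS tile shared by the whole workgroup.
    __shared__ float p_shared_block[shared_block_size];
    (void)p_shared_block;
}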
@@ -227,7 +234,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
    }

    __host__ __device__ static constexpr auto
    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -238,25 +245,30 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
            make_cluster_descriptor_v2(make_tuple(M0, N0));

        return c_blockid_to_m0_n0_block_cluster_adaptor;
    }

    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AKMGridDesc& a_k_m_grid_desc,
        const BKNGridDesc& b_k_n_grid_desc,
        const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
        const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
        const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
@@ -271,15 +283,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        const auto M = a_k_m_grid_desc.GetLength(I1);
        const auto N = b_k_n_grid_desc.GetLength(I1);

        // divide block work by [M, N]
        const auto block_work_idx = c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
            make_multi_index(get_block_1d_id()));

        // HACK: this forces index data into SGPRs
        const index_t m_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
@@ -568,7 +574,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                Number<c_m10_n10_m11_n11_thread_tensor_lengths[I3]>{}));

        const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

        ThreadwiseDynamicTensorSliceTransfer_v1r3<
            FloatAcc,
...
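
`MakeCBlockIdToM0N0BlockClusterAdaptor` wraps `make_cluster_descriptor_v2(make_tuple(M0, N0))`, whose `CalculateBottomIndex` maps a flat block id to an `(m0, n0)` block-tile coordinate. Assuming a plain row-major decomposition (the library encodes this as a tensor adaptor; only the equivalent arithmetic is shown):

struct BlockWorkIdx
{
    int m0; // which row of block tiles
    int n0; // which column of block tiles
};

// Illustrative equivalent of CalculateBottomIndex on an [M0, N0] cluster,
// assuming row-major order.
BlockWorkIdx calculate_bottom_index(int block_1d_id, int N0)
{
    return {block_1d_id / N0, block_1d_id % N0};
}

// Example: M = 1024, N = 512, MPerBlock = NPerBlock = 128 gives
// M0 = 8, N0 = 4, a grid of 32 blocks; block id 9 maps to (m0, n0) = (2, 1).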