Commit b4d598bd authored by Chao Liu

refactor

parent 06ba0a90
@@ -118,19 +118,27 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
     {
-        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
+        throw std::runtime_error(
+            "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 has invalid setting");
     }

+    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
+    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
+
+    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
+    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
+
     // c_m0_m10_m11_n0_n10_n11_grid_desc
     const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
         GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

     using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

-    // c_block_cluster_adaptor
-    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
+    // c_blockid_to_m0_n0_block_cluster_adaptor
+    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

-    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

     const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
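Note for reviewers: CalculateGridSize(M, N) is not part of this hunk. A minimal sketch, assuming one workgroup per MPerBlock x NPerBlock tile of C and divisibility guaranteed by CheckValidity (this is a hypothetical stand-in, not the repo's implementation):

    // Hedged sketch only: one block per (m0, n0) tile of C.
    __host__ index_t CalculateGridSizeSketch(index_t M, index_t N)
    {
        const index_t M0 = M / MPerBlock; // tile-rows; assumes M % MPerBlock == 0
        const index_t N0 = N / NPerBlock; // tile-columns; assumes N % NPerBlock == 0
        return M0 * N0;                   // one workgroup per C tile
    }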
@@ -142,15 +150,18 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     true,
-                                                     true>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     true,
+                                     true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -163,20 +174,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     true,
-                                                     false>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     true,
+                                     false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -189,20 +205,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     false,
-                                                     true>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     false,
+                                     true>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -215,20 +236,25 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }
     else
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
-                                                     FloatAB,
-                                                     FloatC,
-                                                     remove_reference_t<AKMGridDesc>,
-                                                     remove_reference_t<BKNGridDesc>,
-                                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterAdaptor>,
-                                                     false,
-                                                     false>;
+        const auto kernel =
+            kernel_dynamic_gemm_v1r2<GridwiseGemm,
+                                     FloatAB,
+                                     FloatC,
+                                     remove_reference_t<AKMGridDesc>,
+                                     remove_reference_t<BKNGridDesc>,
+                                     remove_reference_t<AKM0M1GridDesc>,
+                                     remove_reference_t<BKN0N1GridDesc>,
+                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
+                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
+                                     false,
+                                     false>;

         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -241,8 +267,10 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           p_c_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
+                                          a_k_m0_m1_grid_desc,
+                                          b_k_n0_n1_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_adaptor);
+                                          c_blockid_to_m0_n0_block_cluster_adaptor);
     }

     return ave_time;
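Note for reviewers: the four instantiations above differ only in the two trailing bool template arguments. The flags themselves are computed earlier in this driver and are not part of this diff; a hedged sketch of plausible formulas, assuming the pipeline consumes KPerBlock per iteration with a 2x-unrolled main loop (these expressions are assumptions, not the repo's code):

    // Assumption: tail is one or two KPerBlock slices depending on parity.
    const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
    const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;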
......
@@ -140,7 +140,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
-        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id())},
+        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
+              get_thread_local_1d_id())},
           a_thread_copy_{
               make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
           b_thread_copy_{
@@ -161,14 +162,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
         static_assert(M0 == 2 && N0 == 2, "wrong");
     }

-    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginIndex(index_t thread_id)
+    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
     {
-        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
         // lower: [M0, M1, N0, N1]
+        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
         constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

-        // upper: [Tid, M0, M11, N0, N11]
         // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
+        // upper: [Tid, M0, M11, N0, N11]
         constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
             make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                        make_pass_through_transform(M0),
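Note for reviewers: the merge transform above encodes Tid = ((m100*N100 + n100)*M101 + m101)*N101 + n101, so inverting it is a row-major digit decomposition of the thread id. A minimal scalar sketch of the [Tid] -> [M100, N100, M101, N101] direction (the real code composes tensor adaptors; lengths here are plain integers for illustration):

    // Hedged sketch: peel least-significant "digits" off thread_id first.
    index_t tid  = thread_id;
    const index_t n101 = tid % N101; tid /= N101;
    const index_t m101 = tid % M101; tid /= M101;
    const index_t n100 = tid % N100; tid /= N100;
    const index_t m100 = tid; // most-significant digit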
......
@@ -17,21 +17,26 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AKMGridDesc,
           typename BKNGridDesc,
+          typename AKM0M1GridDesc,
+          typename BKN0N1GridDesc,
           typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterAdaptor,
+          typename CBlockIdToM0N0BlockClusterAdaptor,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_v1r2(const FloatAB* __restrict__ p_a_grid,
-                             const FloatAB* __restrict__ p_b_grid,
-                             FloatC* __restrict__ p_c_grid,
-                             const AKMGridDesc a_k_m_grid_desc,
-                             const BKNGridDesc b_k_n_grid_desc,
-                             const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-                             const CBlockClusterAdaptor c_block_cluster_desc)
+    kernel_dynamic_gemm_v1r2(
+        const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        const AKMGridDesc a_k_m_grid_desc,
+        const BKNGridDesc b_k_n_grid_desc,
+        const AKM0M1GridDesc a_k_m0_m1_grid_desc,
+        const BKN0N1GridDesc b_k_n0_n1_grid_desc,
+        const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
+        const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -44,8 +49,10 @@ __global__ void
         p_shared_block,
         a_k_m_grid_desc,
         b_k_n_grid_desc,
+        a_k_m0_m1_grid_desc,
+        b_k_n0_n1_grid_desc,
         c_m0_m10_m11_n0_n10_n11_grid_desc,
-        c_block_cluster_desc,
+        c_blockid_to_m0_n0_block_cluster_adaptor,
         integral_constant<bool, HasMainKBlockLoop>{},
         integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
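Note for reviewers: grid descriptors and the block-id adaptor are kernel parameters taken by value, so after this change a_k_m0_m1_grid_desc, b_k_n0_n1_grid_desc, and the adaptor are built once on the host and copied in as kernel arguments, rather than every block rederiving them from the raw K-M/K-N descriptors on the device (see the M0/N0 computation deleted from Run() further down).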
@@ -227,7 +234,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
     }

     __host__ __device__ static constexpr auto
-    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
     {
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);
@@ -238,27 +245,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M0 = M / M1;
         const auto N0 = N / N1;

-        const auto c_block_cluster_adaptor = make_cluster_descriptor_v2(make_tuple(M0, N0));
+        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+            make_cluster_descriptor_v2(make_tuple(M0, N0));

-        return c_block_cluster_adaptor;
+        return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

+    using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
+    using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
     using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
-    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
+    using CBlockIdToM0N0BlockClusterAdaptor =
+        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
-                               FloatC* __restrict__ p_c_grid,
-                               FloatAB* __restrict__ p_shared_block,
-                               const AKMGridDesc& a_k_m_grid_desc,
-                               const BKNGridDesc& b_k_n_grid_desc,
-                               const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-                               const CBlockClusterAdaptor& c_block_cluster_desc,
-                               integral_constant<bool, HasMainKBlockLoop>,
-                               integral_constant<bool, HasDoubleTailKBlockLoop>)
+    __device__ static void
+    Run(const FloatAB* __restrict__ p_a_grid,
+        const FloatAB* __restrict__ p_b_grid,
+        FloatC* __restrict__ p_c_grid,
+        FloatAB* __restrict__ p_shared_block,
+        const AKMGridDesc& a_k_m_grid_desc,
+        const BKNGridDesc& b_k_n_grid_desc,
+        const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
+        const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
+        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
+        const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
+        integral_constant<bool, HasMainKBlockLoop>,
+        integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
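Note for reviewers: make_cluster_descriptor_v2(make_tuple(M0, N0)) builds the mapping that CalculateBottomIndex inverts below. A hedged scalar equivalent, assuming the default row-major (M0-major) block ordering (a hypothetical stand-in, not the adaptor's actual code):

    // What c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(block_id)
    // would compute under that assumption:
    const index_t m0 = block_id / N0; // tile-row of C owned by this workgroup
    const index_t n0 = block_id % N0; // tile-column of C owned by this workgroup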
@@ -271,15 +283,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         const auto M = a_k_m_grid_desc.GetLength(I1);
         const auto N = b_k_n_grid_desc.GetLength(I1);

-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
-
-        const auto M0 = M / M1;
-        const auto N0 = N / N1;
-
         // divide block work by [M, N]
-        const auto block_work_idx =
-            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+        const auto block_work_idx = c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
+            make_multi_index(get_block_1d_id()));

         // HACK: this force index data into SGPR
         const index_t m_block_work_idx = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
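Note on the HACK comment above: __builtin_amdgcn_readfirstlane broadcasts lane 0's value to all lanes, which lets the compiler keep the result in scalar registers (SGPRs) instead of replicating it across per-lane VGPRs; this is safe here because block_work_idx is identical for every lane in the workgroup.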
@@ -568,7 +574,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 Number<c_m10_n10_m11_n11_thread_tensor_lengths[I3]>{}));

         const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
-            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
+            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

         ThreadwiseDynamicTensorSliceTransfer_v1r3<
             FloatAcc,
......