Commit 5dd45128 authored by Chao Liu

refactor

parent 02d23347
@@ -78,26 +78,30 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
         throw std::runtime_error("wrong! GEMM size no divisible");
     }
 
-    const auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
-    const auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
+    const auto M1Old = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
+    const auto N1Old = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
 
-    if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
+    if(!(MPerBlock % M1Old == 0 && NPerBlock % N1Old == 0))
     {
         throw std::runtime_error("wrong! GEMM size no divisible");
     }
 
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
+    const auto M0Old = M / M1Old;
+    const auto N0Old = N / N1Old;
 
-    const auto c_m0_m1_n0_n1_grid_desc =
-        transform_dynamic_tensor_descriptor(c_m_n_grid_desc,
-                                            make_tuple(make_unmerge_transform(make_tuple(M0, M1)),
-                                                       make_unmerge_transform(make_tuple(N0, N1))),
-                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+    const auto c_m0_m1_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
+        c_m_n_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(M0Old, M1Old)),
+                   make_unmerge_transform(make_tuple(N0Old, N1Old))),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
 
     using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
 
+#if 0
+    const auto c_m0_m10_m
+#endif
+
     // out_gemm_block_cluster_desc
     const auto c_block_cluster_desc =
         make_cluster_descriptor_v2(make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
...
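Note: the unmerge transforms above just split each C coordinate into a block index and an intra-block index, i.e. m -> (m / M1, m % M1) and n -> (n / N1, n % N1). A minimal standalone sketch of that arithmetic, using plain C++ with made-up tile sizes rather than the library's descriptor API:

```cpp
#include <cstdio>

int main()
{
    // Assumed per-block tile sizes, illustration only
    // (they play the role of M1Old / N1Old in the code above).
    const int M1 = 128, N1 = 128;

    // A sample C-matrix coordinate.
    const int m = 300, n = 45;

    // unmerge: M -> (M0, M1), N -> (N0, N1)
    const int m0 = m / M1, m1 = m % M1;
    const int n0 = n / N1, n1 = n % N1;

    // merging back recovers the original coordinate
    std::printf("(m, n) = (%d, %d) -> (m0, m1, n0, n1) = (%d, %d, %d, %d)\n",
                m0 * M1 + m1, n0 * N1 + n1, m0, m1, n0, n1);
    return 0;
}
```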
@@ -54,6 +54,15 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
     static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
 
+    static constexpr index_t M100 = M1N1ThreadClusterM100;
+    static constexpr index_t N100 = M1N1ThreadClusterN100;
+
+    static constexpr index_t M101 = M1N1ThreadClusterM101;
+    static constexpr index_t N101 = M1N1ThreadClusterN101;
+
+    static constexpr index_t M11 = M1PerThreadM11;
+    static constexpr index_t N11 = N1PerThreadN11;
+
     static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
     static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
@@ -86,9 +95,44 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
         return b_k_n0_n1_block_desc;
     }
 
+    __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
+    {
+        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
+        // lower: [M, N]
+        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_unmerge_transform(make_tuple(
+                               Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
+                           make_unmerge_transform(make_tuple(
+                               Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
+
+        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
+    {
+        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
+        // lower: [M0, M1, N0, N1]
+        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_pass_through_transform(Number<M0>{}),
+                           make_unmerge_transform(
+                               make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
+                           make_pass_through_transform(Number<N0>{}),
+                           make_unmerge_transform(
+                               make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
+
+        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
+    }
+
     __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
     {
-        return Sequence<M0, M1PerThreadM11, N0, N1PerThreadN11>{};
+        return Sequence<M0, M11, N0, N11>{};
     }
 
     static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
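Note: both new adaptors encode the same fixed decomposition of the C block tile: M is unmerged into (M0, M100, M101, M11) and N into (N0, N100, N101, N11), consistent with M1 = M100 * M101 * M11 and N1 = N100 * N101 * N11 defined earlier. A rough sketch of the merge/unmerge arithmetic this implies, shown for the M side only; the helper and sample lengths below are illustrative, not library code:

```cpp
// Illustrative only: the index relation behind the 8-d block adaptor,
//   m = ((m0 * M100 + m100) * M101 + m101) * M11 + m11
// and likewise for n with (N0, N100, N101, N11).
struct MIdx
{
    int m0, m100, m101, m11;
};

// hypothetical helper: fold an (m0, m100, m101, m11) tuple back into m
constexpr int merge_m(MIdx i, int M100, int M101, int M11)
{
    return ((i.m0 * M100 + i.m100) * M101 + i.m101) * M11 + i.m11;
}

// round-trip check for one sample point with assumed lengths M100=2, M101=8, M11=4
static_assert(merge_m({1, 0, 1, 3}, 2, 8, 4) == 71, "merge of (1, 0, 1, 3) should give 71");
```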
@@ -96,7 +140,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
-        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
+        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
@@ -105,8 +149,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
         static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
 
-        static_assert(BlockSize == M1N1ThreadClusterM101 * M1N1ThreadClusterM100 *
-                                       M1N1ThreadClusterN101 * M1N1ThreadClusterN100,
+        static_assert(BlockSize == M101 * M100 * N101 * N100,
                       "wrong! blocksize and cluster size not consistent");
 
         static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
@@ -118,39 +161,27 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
         static_assert(M0 == 2 && N0 == 2, "wrong");
     }
 
-    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
+    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginIndex(index_t thread_id)
     {
-        // 4-d data space into 4-d thread space
-        constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
-            make_tuple(make_vectorize_transform(M0, 1),
-                       make_vectorize_transform(M1PerThreadM11, M1 / M1PerThreadM11),
-                       make_vectorize_transform(N0, 1),
-                       make_vectorize_transform(N1PerThreadN11, N1 / N1PerThreadN11)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        // thread position 4-d thread space
-        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
-            make_tuple(
-                make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(M1N1ThreadClusterM100, M1N1ThreadClusterM101)),
-                make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(M1N1ThreadClusterN100, M1N1ThreadClusterN101))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
-
-        // 4-d thread space to 1-d thread space
-        constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM100,
-                                                       M1N1ThreadClusterN100,
-                                                       M1N1ThreadClusterM101,
-                                                       M1N1ThreadClusterN101))),
-            make_tuple(Sequence<0, 2, 1, 3>{}),
-            make_tuple(Sequence<0>{}));
-
-        constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
-
-        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
+        // lower: [M0, M1, N0, N1]
+        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
+
+        // upper: [Tid, M0, M11, N0, N11]
+        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
+        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
+            make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
+                       make_pass_through_transform(M0),
+                       make_pass_through_transform(M11),
+                       make_pass_through_transform(N0),
+                       make_pass_through_transform(N11)),
+            make_tuple(
+                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
+
+        return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
     }
 
     template <typename CM0M1N0N1ThreadDesc,
...
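Note: after this change, the thread origin in [M0, M1, N0, N1] comes from chaining the 8-d block adaptor with a merge of (M100, N100, M101, N101) into the flat thread id. A standalone sketch of the equivalent index arithmetic; the cluster and per-thread sizes below are assumed sample values, not taken from this commit:

```cpp
#include <array>
#include <cstdio>

int main()
{
    // Assumed thread-cluster and per-thread tile sizes (illustration only).
    const int M100 = 2, N100 = 2, M101 = 8, N101 = 8;
    const int M11 = 4, N11 = 4;

    const int tid = 37; // sample flat thread id, 0 <= tid < M100 * N100 * M101 * N101

    // Undo the merge (M100, N100, M101, N101) -> tid.
    const int n101 = tid % N101;
    const int m101 = (tid / N101) % M101;
    const int n100 = (tid / (N101 * M101)) % N100;
    const int m100 = tid / (N101 * M101 * N100);

    // Fold the cluster coordinates back into (m1, n1); m0, n0 and the
    // intra-thread offsets (m11, n11) are zero at the origin.
    const std::array<int, 4> origin = {
        0, (m100 * M101 + m101) * M11, 0, (n100 * N101 + n101) * N11};

    std::printf("tid %d -> origin [m0 m1 n0 n1] = [%d %d %d %d]\n",
                tid, origin[0], origin[1], origin[2], origin[3]);
    return 0;
}
```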
@@ -220,10 +220,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                const BKNGridDesc& b_k_n_grid_desc,
                                const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
-#if 0
-                               const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
-                               const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
-#endif
                                FloatAB* __restrict__ p_shared_block,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -508,7 +504,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGridIteratorHacks{};
 
         const auto c_thread_data_idx_on_block =
-            blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
+            blockwise_gemm.CalculateCM0M1N0N1ThreadOriginIndex(get_thread_local_1d_id());
 
         ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
...
@@ -14,7 +14,7 @@
 #define CK_DEVICE_BACKEND_AMD 1
 
 // GPU ID
-#if 0
+#if 1
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
...