Commit 55599afd authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 4cf69087
...@@ -434,10 +434,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 ...@@ -434,10 +434,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx[I0]; constexpr auto K = a_block_mtx.GetLength(I0);
constexpr index_t MPerThread = c_thread_mtx[I0]; constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr index_t NPerThread = c_thread_mtx[I1]; constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster = constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
......
...@@ -57,26 +57,26 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -57,26 +57,26 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
BBlockTransferDstScalarPerVector_N, Number<BBlockTransferDstScalarPerVector_N>{},
MPerThread, Number<MPerThread>{},
NPerThread); Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_multi_index(KPerBlock, MPerBlock), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_multi_index(KPerBlock, NPerBlock), max_lds_align); make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
...@@ -96,14 +96,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -96,14 +96,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
const index_t K = a_k_m_global_desc.GetLength(I0); const auto K = a_k_m_global_desc.GetLength(I0);
const index_t M = a_k_m_global_desc.GetLength(I1); const auto M = a_k_m_global_desc.GetLength(I1);
const index_t N = b_k_n_global_desc.GetLength(I1); const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 0
const index_t m_block_work_num = M / MPerBlock; const auto m_block_work_num = M / Number<MPerBlock>{};
const index_t n_block_work_num = N / NPerBlock; const auto n_block_work_num = N / Number<NPerBlock>{};
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num; const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
...@@ -122,20 +122,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -122,20 +122,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
const index_t n_block_data_on_global = n_block_work_id * NPerBlock; const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment // lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
BBlockTransferDstScalarPerVector_N, Number<BBlockTransferDstScalarPerVector_N>{},
MPerThread, Number<MPerThread>{},
NPerThread); Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), Number<max_lds_align>{}); make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), Number<max_lds_align>{}); make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
...@@ -230,10 +230,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -230,10 +230,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
NPerThread>{}; NPerThread>{};
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size = constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block; Float* p_a_block_double = p_shared_block;
...@@ -372,8 +372,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -372,8 +372,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster; constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
// define input tensor descriptor for threadwise copy // define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy // thread input tensor, src of threadwise copy
......
...@@ -114,8 +114,7 @@ __host__ __device__ constexpr T min(T x, Ts... xs) ...@@ -114,8 +114,7 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
} }
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
template <typename T> __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
__host__ __device__ constexpr T gcd(T x, T y)
{ {
if(x == y || x == 0) if(x == y || x == 0)
{ {
...@@ -143,7 +142,9 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) ...@@ -143,7 +142,9 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
return Number<r>{}; return Number<r>{};
} }
template <typename X, typename... Ys> template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, ys...); return gcd(x, ys...);
...@@ -156,7 +157,9 @@ __host__ __device__ constexpr auto lcm(X x, Y y) ...@@ -156,7 +157,9 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
return (x * y) / gcd(x, y); return (x * y) / gcd(x, y);
} }
template <typename X, typename... Ys> template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) __host__ __device__ constexpr auto lcm(X x, Ys... ys)
{ {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment