Commit 55599afd authored by Chao Liu

refactor

parent 4cf69087
@@ -434,10 +434,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_block_mtx  = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};

-        constexpr index_t K = a_block_mtx[I0];
+        constexpr auto K = a_block_mtx.GetLength(I0);

-        constexpr index_t MPerThread = c_thread_mtx[I0];
-        constexpr index_t NPerThread = c_thread_mtx[I1];
+        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
+        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);

         constexpr index_t MPerLevel1Cluster =
             MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
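This first hunk swaps subscript reads such as `a_block_mtx[I0]`, which decay to plain `index_t` values, for `GetLength(I0)` calls whose results remain `Number<>` integral constants under `constexpr auto`, so they can still parameterize templates downstream. A minimal sketch of the idiom, with a hypothetical `StaticMatrixDesc` standing in for the block/thread matrix types and the `__host__ __device__` qualifiers omitted for brevity:

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Sketch of an integral-constant wrapper in the spirit of Number<>.
template <index_t N>
struct Number
{
    static constexpr index_t value = N;
    constexpr operator index_t() const { return N; }
};

// Hypothetical 2D static descriptor: the lengths live in the type.
template <index_t L0, index_t L1>
struct StaticMatrixDesc
{
    // Indexing with Number<I> lets GetLength return another Number<>,
    // so the result stays a compile-time constant under `constexpr auto`.
    template <index_t I>
    static constexpr auto GetLength(Number<I>)
    {
        constexpr index_t lengths[] = {L0, L1};
        return Number<lengths[I]>{};
    }
};

int main()
{
    constexpr auto I0          = Number<0>{};
    constexpr auto a_block_mtx = StaticMatrixDesc<8, 128>{}; // K x M

    // As in the hunk: `constexpr auto K`, not `constexpr index_t K`,
    // preserves the Number<> type of the length.
    constexpr auto K = a_block_mtx.GetLength(I0);
    static_assert(decltype(K)::value == 8, "K is known at compile time");
}
```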
@@ -57,26 +57,26 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
-                                                    BBlockTransferDstScalarPerVector_N,
-                                                    MPerThread,
-                                                    NPerThread);
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
+                                                 Number<BBlockTransferDstScalarPerVector_N>{},
+                                                 Number<MPerThread>{},
+                                                 Number<NPerThread>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, MPerBlock), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, NPerBlock), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
-        constexpr index_t a_block_space_size =
+        constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

-        constexpr index_t b_block_space_size =
+        constexpr auto b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

         return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
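With `math::lcm` now taking `Number<>` arguments, `max_lds_align` is itself a compile-time constant, and each block's element space is padded up to a multiple of it before the sum is doubled for the two buffers. A sketch of the rounding arithmetic under the name used at the call sites, `integer_least_multiple` (an illustration of the arithmetic, not the repository's definition):

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Round `size` up to the nearest multiple of `alignment` (alignment > 0).
constexpr index_t integer_least_multiple(index_t size, index_t alignment)
{
    return ((size + alignment - 1) / alignment) * alignment;
}

static_assert(integer_least_multiple(1000, 128) == 1024, "padded up");
static_assert(integer_least_multiple(1024, 128) == 1024, "already aligned");
```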
@@ -96,14 +96,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        const index_t K = a_k_m_global_desc.GetLength(I0);
-        const index_t M = a_k_m_global_desc.GetLength(I1);
-        const index_t N = b_k_n_global_desc.GetLength(I1);
+        const auto K = a_k_m_global_desc.GetLength(I0);
+        const auto M = a_k_m_global_desc.GetLength(I1);
+        const auto N = b_k_n_global_desc.GetLength(I1);

         // divide block work by [M, N]
#if 0
-        const index_t m_block_work_num = M / MPerBlock;
-        const index_t n_block_work_num = N / NPerBlock;
+        const auto m_block_work_num = M / Number<MPerBlock>{};
+        const auto n_block_work_num = N / Number<NPerBlock>{};

         const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
         const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
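In this hunk the lengths come from dynamic (runtime) descriptors, so `const auto` deduces an ordinary `index_t` rather than a `Number<>`. Dividing such a runtime value by `Number<MPerBlock>{}` still works when `Number<>` is implicitly convertible to `index_t`; that conversion path is an assumption of this sketch, since the repository may route it through its own operator overloads:

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Number<> sketch as in the earlier example.
template <index_t N>
struct Number
{
    static constexpr index_t value = N;
    constexpr operator index_t() const { return N; }
};

int main()
{
    // A runtime length, as a dynamic descriptor's GetLength() would return.
    const auto M = static_cast<index_t>(1024);

    // Mixed runtime / Number<> division: the Number<> converts to index_t,
    // so the quotient is an ordinary runtime value.
    const auto m_block_work_num = M / Number<128>{};

    return m_block_work_num == 8 ? 0 : 1;
}
```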
@@ -122,20 +122,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         const index_t n_block_data_on_global = n_block_work_id * NPerBlock;

         // lds max alignment
-        constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
-                                                    BBlockTransferDstScalarPerVector_N,
-                                                    MPerThread,
-                                                    NPerThread);
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
+                                                 Number<BBlockTransferDstScalarPerVector_N>{},
+                                                 Number<MPerThread>{},
+                                                 Number<NPerThread>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), Number<max_lds_align>{});
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), Number<max_lds_align>{});
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
@@ -230,10 +230,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                                                    NPerThread>{};

         // LDS allocation for A and B: be careful of alignment
-        constexpr index_t a_block_space_size =
+        constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

-        constexpr index_t b_block_space_size =
+        constexpr auto b_block_space_size =
             math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

         Float* p_a_block_double = p_shared_block;
@@ -372,8 +372,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
         // output: register to global memory
         {
-            constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
-            constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
+            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
+            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};

             // define input tensor descriptor for threadwise copy
             // thread input tensor, src of threadwise copy
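Earlier in this file, the descriptor calls could drop the `Number<max_lds_align>{}` re-wrapping because an lcm computed over `Number<>` arguments can hand its result back as a `Number<>` directly. A hedged sketch of such an overload (the repository's `math::lcm` may differ in detail):

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Number<> sketch as in the earlier examples.
template <index_t N>
struct Number
{
    static constexpr index_t value = N;
    constexpr operator index_t() const { return N; }
};

// Euclid's algorithm for the scalar case.
constexpr index_t gcd(index_t x, index_t y) { return y == 0 ? x : gcd(y, x % y); }

// lcm over Number<> carries the result in the type, so a value computed this
// way can be passed straight on (e.g. as an alignment) without re-wrapping.
template <index_t X, index_t Y>
constexpr auto lcm(Number<X>, Number<Y>)
{
    return Number<(X * Y) / gcd(X, Y)>{};
}

static_assert(decltype(lcm(Number<4>{}, Number<6>{}))::value == 12,
              "lcm of Number<> arguments is itself a Number<>");
```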
@@ -114,8 +114,7 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
 }

 // greatest common divisor, aka highest common factor
-template <typename T>
-__host__ __device__ constexpr T gcd(T x, T y)
+__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
 {
     if(x == y || x == 0)
     {
@@ -143,7 +142,9 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
     return Number<r>{};
 }

-template <typename X, typename... Ys>
+template <typename X,
+          typename... Ys,
+          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto gcd(X x, Ys... ys)
 {
     return gcd(x, gcd(ys...));
@@ -156,7 +157,9 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
     return (x * y) / gcd(x, y);
 }

-template <typename X, typename... Ys>
+template <typename X,
+          typename... Ys,
+          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
 __host__ __device__ constexpr auto lcm(X x, Ys... ys)
 {
     return lcm(x, lcm(ys...));
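The `enable_if<sizeof...(Ys) >= 2>` guard added in the last two hunks keeps the variadic overloads out of two-argument calls, so those always resolve to the dedicated binary overloads; without it, a two-argument call with mixed types could bind to the variadic version and recurse without making progress. A standalone sketch of the constrained pattern with plain `index_t` binaries:

```cpp
#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

constexpr index_t gcd(index_t x, index_t y) { return y == 0 ? x : gcd(y, x % y); }

constexpr index_t lcm(index_t x, index_t y) { return (x * y) / gcd(x, y); }

// Constrained to three or more arguments: two-argument calls fall through to
// the binary overload above instead of re-entering this template.
template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
constexpr auto lcm(X x, Ys... ys)
{
    return lcm(x, lcm(ys...));
}

static_assert(lcm(4, 6, 10) == 60, "reduces as lcm(4, lcm(6, 10))");
```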