"src/scanner.h" did not exist on "de29068110e3aee790661ca070073d9b60354c93"
Commit 3c7e8da2 authored by j4yan's avatar j4yan
Browse files

rename variables and functions in gridwise_gemm_dlops_v1r3

parent 2faeaece
@@ -228,13 +228,13 @@ struct DeviceGemmDlops
                       CThreadTransferDstScalarPerVector>;
     using AGridDesc_K0_M0_M1_K1 =
-        decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{}));
+        decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
     using BGridDesc_K0_N0_N1_K1 =
-        decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{}));
+        decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
     using CGridDesc_M0_M10_M11_N0_N10_N11 =
-        decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{}));
+        decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap =
-        decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{}));
+        decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
     // Argument
     struct Argument : public BaseArgument
@@ -261,10 +261,10 @@ struct DeviceGemmDlops
               c_grid_desc_m0_m10_m11_n0_n10_n11_{},
               block_2_ctile_map_{},
               M01_{M01},
-              N01_{N01}
-              // a_element_op_{a_element_op},
-              // b_element_op_{b_element_op},
-              // c_element_op_{c_element_op}
+              N01_{N01},
+              a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              c_element_op_{c_element_op}
         {
             a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
             b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -274,14 +274,14 @@ struct DeviceGemmDlops
                    a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
             {
                 a_grid_desc_k0_m0_m1_k1_ =
-                    GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_grid_desc_k0_m_k1_);
+                    GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
                 b_grid_desc_k0_n0_n1_k1_ =
-                    GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_grid_desc_k0_n_k1_);
+                    GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
                 c_grid_desc_m0_m10_m11_n0_n10_n11_ =
-                    GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_grid_desc_m_n_);
+                    GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
                 block_2_ctile_map_ =
-                    GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_);
+                    GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
             }
         }
@@ -300,12 +300,14 @@ struct DeviceGemmDlops
         DefaultBlock2CTileMap block_2_ctile_map_;
+        // TODO: unused, but may be useful in future.
         index_t M01_;
         index_t N01_;
-        // AElementwiseOperation a_element_op_;
-        // BElementwiseOperation b_element_op_;
-        // CElementwiseOperation c_element_op_;
+        // TODO: unused since gridwise_gemm_dlops_v1r3 does NOT support prologue for the time being.
+        AElementwiseOperation a_element_op_;
+        BElementwiseOperation b_element_op_;
+        CElementwiseOperation c_element_op_;
     };
     // Invoker
@@ -317,14 +319,14 @@ struct DeviceGemmDlops
        {
            {
                std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
-                         << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0) << ", "
-                         << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I1) << ", "
-                         << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I2) << "}" << std::endl;
+                         << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+                         << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                         << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
                std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
-                         << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I0) << ", "
-                         << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I1) << ", "
-                         << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I2) << "}" << std::endl;
+                         << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+                         << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                         << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
...
@@ -16,10 +16,10 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AK0M0M1K1GridDesc,
-          typename BK0N0N1K1GridDesc,
-          typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockIdToM0N0BlockClusterAdaptor,
+          typename AGridDesc_K0_M0_M1_K1,
+          typename BGridDesc_K0_N0_N1_K1,
+          typename CGridDesc_M0_M10_M11_N0_N10_N11,
+          typename Block2CTileMap,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void
@@ -30,10 +30,10 @@ __global__ void
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-       const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
-       const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
-       const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-       const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
+       const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
+       const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
+       const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
+       const Block2CTileMap block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -44,10 +44,10 @@ __global__ void
                       p_b_grid,
                       p_c_grid,
                       p_shared_block,
-                      a_k0_m0_m1_k1_grid_desc,
-                      b_k0_n0_n1_k1_grid_desc,
-                      c_m0_m10_m11_n0_n10_n11_grid_desc,
-                      cblockid_to_m0_n0_block_cluster_adaptor,
+                      a_grid_desc_k0_m0_m1_k1,
+                      b_grid_desc_k0_n0_n1_k1,
+                      c_grid_desc_m0_m10_m11_n0_n10_n11,
+                      block_2_ctile_map,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
@@ -57,12 +57,12 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename AGridDesc_K0_M_K1,
+          typename BGridDesc_K0_N_K1,
           typename CMNGridDesc,
-          index_t MPerBlockM1,
-          index_t NPerBlockN1,
-          index_t KPerBlock,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t K0PerBlock,
           index_t M1PerThreadM111,
           index_t N1PerThreadN111,
           index_t KPerThread,
@@ -93,7 +93,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
     static constexpr auto I3 = Number<3>{};
     // K1 should be Number<...>
-    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
+    static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
@@ -102,112 +102,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
+        constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
+        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
         // TODO: check alignment
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
         constexpr auto b_block_aligned_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
+            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
         return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
     }
     __host__ __device__ static constexpr bool
-    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                  const CMNGridDesc& c_m_n_grid_desc)
+    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const CMNGridDesc& c_grid_desc_m_n)
     {
-        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
-                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
-               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
+        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+                K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
+                K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
+                K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
+               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
     }
     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
-        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
+        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
         return grid_size;
     }
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
     {
-        const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
+        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
         return has_main_k_block_loop;
     }
     __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
     {
-        const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
+        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
         return has_double_tail_k_block_loop;
     }
     __host__ __device__ static constexpr auto
-    MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
+    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
     {
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto M1 = Number<MPerBlockM1>{};
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto M1 = Number<MPerBlock>{};
         const auto M0 = M / M1;
-        const auto a_k0_m0_m1_k1_grid_desc =
-            transform_tensor_descriptor(a_k0_m_k1_grid_desc,
+        const auto a_grid_desc_k0_m0_m1_k1 =
+            transform_tensor_descriptor(a_grid_desc_k0_m_k1,
                                         make_tuple(make_pass_through_transform(K0),
                                                    make_unmerge_transform(make_tuple(M0, M1)),
                                                    make_pass_through_transform(K1)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-        return a_k0_m0_m1_k1_grid_desc;
+        return a_grid_desc_k0_m0_m1_k1;
     }
     __host__ __device__ static constexpr auto
-    MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
+    MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
     {
-        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
-        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
-        const auto N1 = Number<NPerBlockN1>{};
+        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
+        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
+        const auto N1 = Number<NPerBlock>{};
         const auto N0 = N / N1;
-        const auto b_k0_n0_n1_k1_grid_desc =
-            transform_tensor_descriptor(b_k0_n_k1_grid_desc,
+        const auto b_grid_desc_k0_n0_n1_k1 =
+            transform_tensor_descriptor(b_grid_desc_k0_n_k1,
                                         make_tuple(make_pass_through_transform(K0),
                                                    make_unmerge_transform(make_tuple(N0, N1)),
                                                    make_pass_through_transform(K1)),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
-        return b_k0_n0_n1_k1_grid_desc;
+        return b_grid_desc_k0_n0_n1_k1;
     }
     __host__ __device__ static constexpr auto
-    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CMNGridDesc& c_grid_desc_m_n)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlockM1>{};
-        constexpr auto N1 = Number<NPerBlockN1>{};
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
         const auto M0 = M / M1;
         const auto N0 = N / N1;
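Aside (not part of the commit): a minimal standalone sketch, not CK code, of the arithmetic behind GetSharedMemoryNumberOfByte() in the hunk above. The tile parameters and element size below are assumed example values, and the math::integer_least_multiple alignment padding is ignored for simplicity.

// Standalone sketch, not CK code: LDS double-buffer size for one workgroup.
#include <cstdio>

int main()
{
    constexpr int K0PerBlock = 8;   // assumed
    constexpr int MPerBlock  = 128; // assumed
    constexpr int NPerBlock  = 128; // assumed
    constexpr int K1         = 4;   // assumed
    constexpr int ab_bytes   = 2;   // sizeof(FloatAB) for fp16

    // One A tile and one B tile in elements, as described by a_block_desc_k_m
    // and b_block_desc_k_n above (alignment padding ignored).
    constexpr int a_tile_elems = K0PerBlock * MPerBlock * K1; // 4096
    constexpr int b_tile_elems = K0PerBlock * NPerBlock * K1; // 4096

    // The factor 2 is the even/odd double buffer.
    constexpr int lds_bytes = 2 * (a_tile_elems + b_tile_elems) * ab_bytes; // 32768

    std::printf("LDS bytes per workgroup: %d\n", lds_bytes);
    return 0;
}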
@@ -222,41 +222,41 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;
-        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
-            c_m_n_grid_desc,
+        const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
+            c_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                        make_unmerge_transform(make_tuple(N0, N10, N11))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
-        return c_m0_m10_m11_n0_n10_n11_grid_desc;
+        return c_grid_desc_m0_m10_m11_n0_n10_n11;
     }
     __host__ __device__ static constexpr auto
-    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    MakeDefaultBlock2CTileMap(const CMNGridDesc& c_grid_desc_m_n)
     {
-        const auto M = c_m_n_grid_desc.GetLength(I0);
-        const auto N = c_m_n_grid_desc.GetLength(I1);
-        constexpr auto M1 = Number<MPerBlockM1>{};
-        constexpr auto N1 = Number<NPerBlockN1>{};
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
         const auto M0 = M / M1;
         const auto N0 = N / N1;
-        const auto cblockid_to_m0_n0_block_cluster_adaptor =
+        const auto block_2_ctile_map =
             make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
                                              make_tuple(Sequence<0, 1>{}),
                                              make_tuple(Sequence<0>{}));
-        return cblockid_to_m0_n0_block_cluster_adaptor;
+        return block_2_ctile_map;
     }
-    using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
-    using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
-    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
-    using CBlockIdToM0N0BlockClusterAdaptor =
-        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
+    using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
+    using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
+    using CGridDesc_M0_M10_M11_N0_N10_N11 =
+        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CMNGridDesc{}));
+    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CMNGridDesc{}));
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void
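Aside (not part of the commit): MakeDefaultBlock2CTileMap above wraps a single merge transform over (M0, N0), so CalculateBottomIndex turns the 1D workgroup id back into a C-tile coordinate. A plain-integer sketch of that mapping, under the assumption that the merge enumerates (m0, n0) row-major with n0 fastest; the problem and tile sizes are made-up example values.

// Standalone sketch, not CK code: workgroup id -> C tile (m0, n0).
#include <cstdio>

int main()
{
    const int M = 1024, N = 512;                // assumed problem size
    const int MPerBlock = 128, NPerBlock = 128; // assumed tile size

    const int M0 = M / MPerBlock; // 8 tile rows
    const int N0 = N / NPerBlock; // 4 tile columns

    const int grid_size = M0 * N0; // what CalculateGridSize() returns: 32

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int m0 = block_id / N0; // plays the role of CalculateBottomIndex()[I0]
        const int n0 = block_id % N0; // plays the role of CalculateBottomIndex()[I1]

        if(block_id < 3)
            std::printf("block %d -> C tile (%d, %d)\n", block_id, m0, n0);
    }
    return 0;
}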
@@ -264,24 +264,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         FloatAB* __restrict__ p_shared_block,
-        const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
-        const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
-        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-        const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
+        const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
+        const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
+        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
+        const Block2CTileMap& block_2_ctile_map,
         integral_constant<bool, HasMainKBlockLoop>,
         integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
+            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
+            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
         // divide block work by [M, N]
         const auto c_m0_n0_block_cluster_idx =
-            cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
-                make_multi_index(get_block_1d_id()));
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
         // HACK: this force index data into SGPR
         const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
@@ -293,28 +292,28 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         // TODO: check alignment
         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
+        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
         // TODO: check alignment
         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
-        constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
+        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
+            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
         // TODO: check alignment
         // B matrix in LDS memory, for blockwise GEMM
         constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
+            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-        static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
+        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                           a_k0_m_k1_block_desc.GetElementSpaceSize() &&
-                          b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
+                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                               b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                       "wrong!");
@@ -322,14 +321,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
             InMemoryDataOperationEnum::Set,
-            Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
+            Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
             ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
             ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
             FloatAB,
-            remove_reference_t<decltype(a_k0_m0_m1_k1_grid_desc)>,
-            decltype(a_k0_m0_m1_k1_block_desc),
+            remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
+            decltype(a_block_desc_k0_m0_m1_k1),
             ABlockTransferSrcAccessOrder,
             Sequence<0, 1, 2, 3>,
             ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
@@ -337,23 +336,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
             Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
             false,
-            true>(a_k0_m0_m1_k1_grid_desc,
+            true>(a_grid_desc_k0_m0_m1_k1,
                   make_multi_index(0, im0, 0, 0),
-                  a_k0_m0_m1_k1_block_desc,
+                  a_block_desc_k0_m0_m1_k1,
                   make_multi_index(0, 0, 0, 0));
         // B matrix blockwise copy
         auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
             BlockSize,
             InMemoryDataOperationEnum::Set,
-            Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
+            Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
             BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
             BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,
             FloatAB,
-            remove_reference_t<decltype(b_k0_n0_n1_k1_grid_desc)>,
-            decltype(b_k0_n0_n1_k1_block_desc),
+            remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
+            decltype(b_block_desc_k0_n0_n1_k1),
             BBlockTransferSrcAccessOrder,
             Sequence<0, 1, 2, 3>,
             BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
@@ -361,16 +360,16 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
             Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
             false,
-            true>(b_k0_n0_n1_k1_grid_desc,
+            true>(b_grid_desc_k0_n0_n1_k1,
                   make_multi_index(0, in0, 0, 0),
-                  b_k0_n0_n1_k1_block_desc,
+                  b_block_desc_k0_n0_n1_k1,
                   make_multi_index(0, 0, 0, 0));
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
-        // a_mtx[KPerBlock, MPerBlockM1] is in LDS
-        // b_mtx[KPerBlocl, NPerBlockN1] is in LDS
-        // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
+        // a_mtx[K0PerBlock, MPerBlock] is in LDS
+        // b_mtx[KPerBlocl, NPerBlock] is in LDS
+        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         const auto blockwise_gemm =
             BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
@@ -391,58 +390,58 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
-        constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
+        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
             sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
-            a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
         constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
         FloatAB* p_a_block_double = p_shared_block;
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
-            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
+            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
         ThreadwiseTensorSliceSet_v1<FloatAcc,
-                                    decltype(c_m10_m11_n10_n11_thread_desc),
+                                    decltype(c_thread_desc_m10_m11_n10_n11),
                                     decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
-            .Run(c_m10_m11_n10_n11_thread_desc,
+            .Run(c_thread_desc_m10_m11_n10_n11,
                  make_tuple(I0, I0, I0, I0),
                  c_thread_buf,
                  FloatAcc{0});
-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
+            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
         auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
+            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
-            a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
+            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
         auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block_double + b_block_aligned_space_size,
-            b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
+            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);
-            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
        }
         if constexpr(HasMainKBlockLoop)
         {
-            const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
+            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
             index_t k_block_data_begin = 0;
@@ -451,76 +450,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                     b_block_slice_copy_step);
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
-                b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
+                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                    a_block_even_buf,
                                    b_block_even_buf,
                                    c_thread_buf);
                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
+                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                     b_block_slice_copy_step);
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
-                a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
-                b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);
+                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(
-                    c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                    c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
                 // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
+                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
-                k_block_data_begin += 2 * KPerBlock;
-            } while(k_block_data_begin < K0 - 2 * KPerBlock);
+                k_block_data_begin += 2 * K0PerBlock;
+            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
         }
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
             __syncthreads();
             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);
+            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
+            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
-                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
+                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
             // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
+            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
             __syncthreads();
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(
-                c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
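Aside (not part of the commit): a host-side sketch, not CK code, of how CalculateHasMainKBlockLoop, the do-while in the hunk above, and the one- or two-step tail branches split K0 between the even and odd LDS buffers. K0PerBlock and the sample K0 values are assumed.

// Standalone sketch, not CK code: main-loop vs. tail iteration counts.
#include <cstdio>

int main()
{
    const int K0PerBlock = 8; // assumed

    for(int K0 : {8, 16, 24, 32})
    {
        const bool has_main_loop   = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
        const bool has_double_tail = (K0 / K0PerBlock) % 2 == 0;

        int main_iters         = 0;
        int k_block_data_begin = 0;

        if(has_main_loop)
        {
            do
            {
                ++main_iters; // one even + one odd K0PerBlock step per main iteration
                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }

        const int tail_iters = has_double_tail ? 2 : 1;

        std::printf("K0=%2d: main iterations=%d, tail iterations=%d\n",
                    K0, main_iters, tail_iters);
    }
    return 0;
}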
@@ -528,12 +527,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(
-                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
+                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
         }
         // output: register to global memory
         {
-            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
+            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                 make_naive_tensor_descriptor_packed(
                     make_tuple(I1,
                                Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
@@ -549,8 +548,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             ThreadwiseTensorSliceTransfer_v1r3<
                 FloatAcc,
                 FloatC,
-                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
-                decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
+                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
+                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                 ck::tensor_operation::element_wise::PassThrough,
                 Sequence<1,
                          c_m10_m11_n10_n11_thread_tensor_lengths[I0],
@@ -563,7 +562,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                 CThreadTransferDstScalarPerVector,
                 CGlobalMemoryDataOperation,
                 1,
-                true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
+                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                       make_multi_index(im0,
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
@@ -571,10 +570,10 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                        c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                       ck::tensor_operation::element_wise::PassThrough{}}
-                .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
+                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                      make_tuple(I0, I0, I0, I0, I0, I0),
                      c_thread_buf,
-                     c_m0_m10_m11_n0_n10_n11_grid_desc,
+                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                      c_grid_buf);
         }
     }
...