"src/include/threadwise_tensor_slice_op.hpp" did not exist on "e72eece8fcf79d1d3a958089fca1f02bfb71b777"
Commit 3a558e59 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move AK0/BK0 compute logic into GridwiseGemm

parent 7a62d4a7
...@@ -82,8 +82,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -82,8 +82,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static auto static auto MakeAGridDescriptor_AK0_M_AK1(
MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA) index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -100,10 +100,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -100,10 +100,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both M and K // pad both M and K
assert(K % AK1 == 0);
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_tuple(make_right_pad_transform(M, MPad - M),
...@@ -124,10 +120,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -124,10 +120,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad M, but not K // pad M, but not K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -141,10 +133,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -141,10 +133,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::NKPadding) GemmSpec == GemmSpecialization::NKPadding)
{ {
// pad K, but not M // pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = KPad / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor( const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw, a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
...@@ -163,10 +151,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -163,10 +151,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else else
{ {
// not pad M or K // not pad M or K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw, transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
...@@ -178,8 +162,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -178,8 +162,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
} }
static auto static auto MakeBGridDescriptor_BK0_N_BK1(
MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB) index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -196,10 +180,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -196,10 +180,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNKPadding) GemmSpec == GemmSpecialization::MNKPadding)
{ {
// pad both N and K // pad both N and K
assert(K % BK1 == 0);
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N), make_tuple(make_right_pad_transform(N, NPad - N),
...@@ -220,10 +200,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -220,10 +200,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MNPadding) GemmSpec == GemmSpecialization::MNPadding)
{ {
// pad N, but not K // pad N, but not K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
...@@ -237,10 +213,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -237,10 +213,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GemmSpec == GemmSpecialization::MKPadding) GemmSpec == GemmSpecialization::MKPadding)
{ {
// pad K, but not N // pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = KPad / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor( const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
...@@ -259,10 +231,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -259,10 +231,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
else else
{ {
// not pad N or K // not pad N or K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw, transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
...@@ -325,8 +293,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -325,8 +293,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// GridwiseGemm // GridwiseGemm
...@@ -338,6 +306,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -338,6 +306,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
...@@ -399,13 +368,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -399,13 +368,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M),
K, K,
GridwiseGemm::CalculateKPadded(K), GridwiseGemm::CalculateKPadded(K),
StrideA)}, StrideA,
GridwiseGemm::CalculateAK0(K))},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K,
GridwiseGemm::CalculateKPadded(K), GridwiseGemm::CalculateKPadded(K),
N, N,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
StrideB)}, StrideB,
GridwiseGemm::CalculateBK0(K))},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M),
N, N,
......
...@@ -79,6 +79,7 @@ template <typename FloatAB, ...@@ -79,6 +79,7 @@ template <typename FloatAB,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
...@@ -150,6 +151,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -150,6 +151,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return (K + KPerBlock - 1) / KPerBlock * KPerBlock; return (K + KPerBlock - 1) / KPerBlock * KPerBlock;
} }
__host__ __device__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
assert(CalculateKPadded(K) % AK1 == 0);
return CalculateKPadded(K) / AK1;
}
else
{
assert(K % AK1 == 0);
return K / AK1;
}
}
__host__ __device__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
assert(CalculateKPadded(K) % BK1 == 0);
return CalculateKPadded(K) / BK1;
}
else
{
assert(K % BK1 == 0);
return K / BK1;
}
}
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment