Commit 708fd81f authored by aska-0096

batched gemm, conv, skip b lds

parent 060c4f3a
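
In short: the batched-contraction and grouped-conv-fwd WMMA device ops gain a skip-LDS path for the B operand. When BEnableLds is false, MakeBGridDescriptor no longer produces the K0 x N x K1 layout that is staged through LDS; it unmerges K into (KWmma, KRow, K1) and N into (N0 * NRepeat, NWaves, NPerWmma), so B is consumed directly in its WMMA access pattern, and the kernel launcher recovers K as the product of the split K dimensions. The standalone sketch below is not part of the commit and uses hypothetical example tile sizes; it only replays that index split and checks the K recovery.

// Standalone sketch (not CK code, not part of this commit): reproduces the index split
// that MakeBGridDescriptor performs for B[N, K] when BEnableLds == false, using
// hypothetical example tile sizes.
#include <array>
#include <cassert>
#include <cstdio>

int main()
{
    // Hypothetical tile configuration (example values, not taken from the instances above).
    const int K1        = 8;  // contiguous K elements read per thread
    const int WmmaK     = 16; // K extent of one WMMA instruction
    const int NPerWmma  = 16;
    const int NWaves    = 2;
    const int NRepeat   = 2;
    const int NPerBlock = NRepeat * NWaves * NPerWmma;

    // Hypothetical problem sizes; the descriptor maker asserts K % K1 == 0,
    // the other divisibilities are assumed here for simplicity.
    const int N = 256, K = 64;
    assert(K % K1 == 0 && K % WmmaK == 0 && N % NPerBlock == 0);

    // Unmerge K -> (B_KWmma, B_KRow, K1) and N -> (N0 * NRepeat, NWaves, NPerWmma),
    // matching the output dimension order Sequence<0, 3, 5> / Sequence<1, 2, 4>.
    const int B_KRow  = WmmaK / K1;
    const int B_KWmma = K / WmmaK;
    const int N0      = N / NPerBlock;
    const std::array<int, 6> lengths = {B_KWmma, N0 * NRepeat, NWaves, B_KRow, NPerWmma, K1};

    // The launcher recovers K from the split K dimensions: I0 * I3 * I5.
    const int K_recovered = lengths[0] * lengths[3] * lengths[5];
    std::printf("6-d B lengths: [%d, %d, %d, %d, %d, %d], recovered K = %d\n",
                lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5],
                K_recovered);
    assert(K_recovered == K);
    return 0;
}

Compiled with any C++11 (or later) compiler, this prints [4, 8, 2, 2, 16, 8] and recovered K = 64 for the example sizes above.
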
@@ -74,8 +74,8 @@ using DeviceOpInstanceKKNN =
8,
16,
16,
1,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN =
true,
1,
1,
S<1, 128, 1, 2>,
S<1, 16, 1, 16>,
8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
@@ -58,8 +58,8 @@ using DeviceConvFwdInstance =
8, // K1
16, // MPerWMMA
16, // NPerWMMA
1, // MRepeat
8, // NRepeat
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
true, // BBlockLdsExtraN
1,
1,
S<1, 128, 1, 2>,
S<1, 32, 1, 8>,
8>;
template <ck::index_t NDimSpatial>
@@ -228,7 +228,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_K0_N_K1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
@@ -287,6 +287,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);
if constexpr(BEnableLds)
{
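// LDS path: keep the blocked K0 x N x K1 layout that the block-wise copy stages through LDS.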
const index_t K0 = K / K1;
return transform_tensor_descriptor(
@@ -296,6 +299,22 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
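// Skip-LDS path: unmerge K into (B_KWmma, B_KRow, K1) and N into (N0 * NRepeat, NWaves, NPerWmma)
// so the thread-wise copy can feed B from global memory directly in its WMMA access pattern.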
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
@@ -504,7 +523,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
};
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1({},{}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({},{}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
@@ -517,7 +536,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType,
// InMemory Data Descriptor
AGridDesc,
BGridDesc_K0_N_K1,
BGridDesc,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
@@ -586,8 +605,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
a_grid_desc_{},
b_grid_desc_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{
@@ -620,8 +639,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
});
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
@@ -660,8 +679,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType* p_e_grid_;
// Tensor Descriptors
AGridDesc a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
@@ -714,8 +733,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const auto K = [&]() {
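// With LDS the A descriptor is (K0, M, K1), so K = I0 * I2; without LDS the K extent is
// split across the WMMA layout (as in the B unmerge above), so it is recovered as I0 * I3 * I5.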
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
}();
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -727,7 +755,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
typename GridwiseOp::DsGridPointer,
EDataType,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
@@ -747,8 +775,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg.p_ds_grid_,
arg.p_e_grid_,
G,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
@@ -797,8 +825,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
@@ -816,7 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-m check failure\n");
return false;
@@ -825,7 +853,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
@@ -836,7 +864,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-n check failure\n");
return false;
@@ -845,7 +873,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
else
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-k check failure\n");
return false;
@@ -245,7 +245,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
@@ -257,16 +257,35 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
assert(K % K1 == 0);
const auto BK1 = K1;
const auto BK0 = K / BK1;
if constexpr(BEnableLds)
{
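// LDS path: same K0 x N x K1 blocked layout as before, staged through LDS.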
const index_t K0 = K / K1;
return transform_tensor_descriptor(wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
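// Skip-LDS path: same WMMA-oriented unmerge of K and N as in the batched-contraction device op above.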
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
}
template <typename ELay>
static auto
@@ -302,7 +321,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
@@ -315,7 +334,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
EDataType,
// InMemory Data Descriptor
AGridDesc,
BGridDesc_BK0_N_BK1,
BGridDesc,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
@@ -394,7 +413,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
@@ -404,7 +423,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths,
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
@@ -457,8 +476,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -477,8 +496,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc a_grid_desc;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -525,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
const auto K =
arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
const auto K = [&]() {
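// Same K recovery as in the batched-contraction launcher: I0 * I2 when A goes through LDS,
// I0 * I3 * I5 when it does not.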
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
}();
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -541,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
@@ -561,8 +589,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
@@ -731,8 +759,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
// check Gridwise GEMM
return GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b_grid_desc_bk0_n_bk1_,
return GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);