Commit 708fd81f authored by aska-0096's avatar aska-0096
Browse files

batched gemm, conv, skip b lds

parent 060c4f3a
...@@ -74,8 +74,8 @@ using DeviceOpInstanceKKNN = ...@@ -74,8 +74,8 @@ using DeviceOpInstanceKKNN =
8, 8,
16, 16,
16, 16,
1,
8, 8,
1,
S<4, 64, 1>, S<4, 64, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN = ...@@ -92,7 +92,7 @@ using DeviceOpInstanceKKNN =
true, true,
1, 1,
1, 1,
S<1, 128, 1, 2>, S<1, 16, 1, 16>,
8>; 8>;
using DeviceOpInstance = DeviceOpInstanceKKNN; using DeviceOpInstance = DeviceOpInstanceKKNN;
......
...@@ -58,8 +58,8 @@ using DeviceConvFwdInstance = ...@@ -58,8 +58,8 @@ using DeviceConvFwdInstance =
8, // K1 8, // K1
16, // MPerWMMA 16, // MPerWMMA
16, // NPerWMMA 16, // NPerWMMA
1, // MRepeat 4, // MRepeat
8, // NRepeat 2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
...@@ -76,7 +76,7 @@ using DeviceConvFwdInstance = ...@@ -76,7 +76,7 @@ using DeviceConvFwdInstance =
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
1, 1,
1, 1,
S<1, 128, 1, 2>, S<1, 32, 1, 8>,
8>; 8>;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
......
...@@ -228,7 +228,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -228,7 +228,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
} }
// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_K0_N_K1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec) const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{ {
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
...@@ -287,14 +287,33 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -287,14 +287,33 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1); const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0); assert(K % K1 == 0);
const index_t K0 = K / K1;
if constexpr(BEnableLds)
return transform_tensor_descriptor( {
b_grid_desc_n_k, const index_t K0 = K / K1;
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)), return transform_tensor_descriptor(
make_tuple(Sequence<1>{}, Sequence<0>{}), b_grid_desc_n_k,
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
} }
// assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
...@@ -504,7 +523,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -504,7 +523,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}; };
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{})); using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({},{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1({},{})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({},{}));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
...@@ -517,7 +536,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -517,7 +536,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType, EDataType,
// InMemory Data Descriptor // InMemory Data Descriptor
AGridDesc, AGridDesc,
BGridDesc_K0_N_K1, BGridDesc,
DsGridDesc_M_N, DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
// ElementwiseOp Family // ElementwiseOp Family
...@@ -586,8 +605,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -586,8 +605,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_k0_m_k1_{}, a_grid_desc_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_{},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{}, e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{ ds_grid_desc_g_m_n_{
...@@ -620,8 +639,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -620,8 +639,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
}); });
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
ds_grid_desc_m_n_ = ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
...@@ -660,8 +679,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -660,8 +679,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType* p_e_grid_; EDataType* p_e_grid_;
// Tensor Descriptors // Tensor Descriptors
AGridDesc a_grid_desc_k0_m_k1_; AGridDesc a_grid_desc_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc b_grid_desc_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
...@@ -714,8 +733,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -714,8 +733,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
const auto K = const auto K = [&]() {
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
}();
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
...@@ -727,7 +755,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -727,7 +755,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
typename GridwiseOp::DsGridPointer, typename GridwiseOp::DsGridPointer,
EDataType, EDataType,
DeviceOp::AGridDesc, DeviceOp::AGridDesc,
DeviceOp::BGridDesc_K0_N_K1, DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation, AElementwiseOperation,
...@@ -747,8 +775,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -747,8 +775,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
G, G,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock, arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_, arg.a_element_op_,
...@@ -797,8 +825,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -797,8 +825,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return false; return false;
} }
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_)) arg.block_2_ctile_map_))
...@@ -816,7 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -816,7 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(ABlockTransferSrcVectorDim == 1) if constexpr(ABlockTransferSrcVectorDim == 1)
{ {
if(!(arg.a_mz_stride_ == 1 && if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{ {
printf("DeviceOp: Vector Access A-m check failure\n"); printf("DeviceOp: Vector Access A-m check failure\n");
return false; return false;
...@@ -825,7 +853,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -825,7 +853,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
else else
{ {
if(!(arg.a_kz_stride_ == 1 && if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{ {
printf("DeviceOp: Vector Access A-k check failure\n"); printf("DeviceOp: Vector Access A-k check failure\n");
return false; return false;
...@@ -836,7 +864,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -836,7 +864,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(BBlockTransferSrcVectorDim == 1) if constexpr(BBlockTransferSrcVectorDim == 1)
{ {
if(!(arg.b_nz_stride_ == 1 && if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{ {
printf("DeviceOp: Vector Access B-n check failure\n"); printf("DeviceOp: Vector Access B-n check failure\n");
return false; return false;
...@@ -845,7 +873,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -845,7 +873,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
else else
{ {
if(!(arg.b_kz_stride_ == 1 && if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{ {
printf("DeviceOp: Vector Access B-k check failure\n"); printf("DeviceOp: Vector Access B-k check failure\n");
return false; return false;
......
...@@ -245,7 +245,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -245,7 +245,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides) const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
...@@ -257,15 +257,34 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -257,15 +257,34 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto N = wei_gemmn_gemmk_desc.GetLength(I0); const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1); const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
assert(K % K1 == 0);
if constexpr(BEnableLds)
{
const index_t K0 = K / K1;
const auto BK1 = K1; return transform_tensor_descriptor(
const auto BK0 = K / BK1; wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = WmmaK / K1;
const auto B_KWmma = K / WmmaK;
return transform_tensor_descriptor(wei_gemmn_gemmk_desc, const auto N0 = N / NPerBlock;
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)), return transform_tensor_descriptor(
make_tuple(Sequence<1>{}, Sequence<0>{}), wei_gemmn_gemmk_desc,
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
} }
template <typename ELay> template <typename ELay>
...@@ -302,7 +321,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -302,7 +321,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {})); using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
...@@ -315,7 +334,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -315,7 +334,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
EDataType, EDataType,
// InMemory Data Descriptor // InMemory Data Descriptor
AGridDesc, AGridDesc,
BGridDesc_BK0_N_BK1, BGridDesc,
DsGridDesc_M_N, DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
// ElementwiseOp Family // ElementwiseOp Family
...@@ -394,7 +413,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -394,7 +413,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)}, e_g_n_k_wos_strides)},
a_grid_desc{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths, a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -404,7 +423,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -404,7 +423,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)}, b_g_k_c_xs_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
...@@ -457,8 +476,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -457,8 +476,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
void Print() const void Print() const
{ {
std::cout << "A[M, K]: " << a_grid_desc << std::endl; std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_bk0_n_bk1_ << std::endl; std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
...@@ -477,8 +496,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -477,8 +496,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc a_grid_desc; AGridDesc a_grid_desc_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc b_grid_desc_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -525,8 +544,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -525,8 +544,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
const auto K = const auto K = [&]() {
arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2); if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I5);
}
}();
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
...@@ -541,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -541,7 +569,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc, DeviceOp::AGridDesc,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
...@@ -561,8 +589,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -561,8 +589,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc, arg.a_grid_desc_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_, arg.block_2_etile_map_,
...@@ -731,8 +759,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -731,8 +759,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
// check Gridwise GEMM // check Gridwise GEMM
return GridwiseOp::CheckValidity(arg.a_grid_desc, return GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment