Commit 840cba8e authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/moe

parents bf8e6de7 73b67f29
......@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return false;
if constexpr(!((NDimSpatial == 1 &&
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 2 &&
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 3 &&
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
{
return false;
}
......
......@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
......
......@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return false;
}
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
......
......@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
......
......@@ -925,7 +925,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
......@@ -941,7 +941,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
......@@ -960,7 +960,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
......
......@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
......
......@@ -12,7 +12,7 @@ namespace device {
// 1d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NWGK_GKXC_NWGC()
constexpr bool is_NWGC_GKXC_NWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
......@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNWK_GKXC_GNWC()
constexpr bool is_GNWC_GKXC_GNWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
......@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
}
// 2d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGK_GKYXC_NHWGC()
constexpr bool is_NHWGC_GKYXC_NHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
......@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNHWK_GKYXC_GNHWC()
constexpr bool is_GNHWC_GKYXC_GNHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_GKYXC_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
......@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
......@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
{
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
{
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
}
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
static constexpr const char* name = "NDHWGC";
};
// input tensor
// packed NGCW/NGCHW/NGCDHW
struct NGCW : public BaseTensorLayout
{
static constexpr const char* name = "NGCW";
};
struct NGCHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCHW";
};
struct NGCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCDHW";
};
// input tensor
// strided layout
struct G_NW_C : public BaseTensorLayout
......@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
static constexpr const char* name = "NDHWGK";
};
struct NGKW : public BaseTensorLayout
{
static constexpr const char* name = "NGKW";
};
struct NGKHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKHW";
};
struct NGKDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKDHW";
};
// output tensor
// strided layout
struct G_NW_K : public BaseTensorLayout
......
......@@ -41,6 +41,55 @@ __global__ void
elementwise_op);
}
template <typename GridwiseElementwiseFunctor,
typename InAGridDescTuple,
typename InBGridDescTuple,
typename OutAGridDescTuple,
typename OutBGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMapA,
typename Block2TileMapB,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
const InBGridDescTuple in_grid_desc_tuple_b,
const OutAGridDescTuple out_grid_desc_tuple_a,
const OutBGridDescTuple out_grid_desc_tuple_b,
const InDataTypePointerTuple p_in_global_tuple_a,
const InDataTypePointerTuple p_in_global_tuple_b,
const OutDataTypePointerTuple p_out_global_tuple_a,
const OutDataTypePointerTuple p_out_global_tuple_b,
const Block2TileMapA block_2_tile_map_a,
const Block2TileMapB block_2_tile_map_b,
const ElementwiseOperation elementwise_op,
const index_t a_grid_size)
{
if(get_block_1d_id() < a_grid_size)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
out_grid_desc_tuple_a,
p_in_global_tuple_a,
p_out_global_tuple_a,
block_2_tile_map_a,
elementwise_op,
get_block_1d_id());
}
else
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
out_grid_desc_tuple_b,
p_in_global_tuple_b,
p_out_global_tuple_b,
block_2_tile_map_b,
elementwise_op,
get_block_1d_id() - a_grid_size);
}
}
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
......@@ -133,7 +182,8 @@ struct GridwiseElementwise
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
const ElementwiseOperation& elementwise_op,
const index_t block_id = get_block_1d_id())
{
constexpr auto src_datas = generate_tuple(
......@@ -169,7 +219,7 @@ struct GridwiseElementwise
Number<NumOutput>{});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
......
......@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
using NGKW = ck::tensor_layout::convolution::NGKW;
using NGKHW = ck::tensor_layout::convolution::NGKHW;
using NGKDHW = ck::tensor_layout::convolution::NGKDHW;
//
using NWGC = ck::tensor_layout::convolution::NWGC;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
......@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
using NGCW = ck::tensor_layout::convolution::NGCW;
using NGCHW = ck::tensor_layout::convolution::NGCHW;
using NGCDHW = ck::tensor_layout::convolution::NGCDHW;
//
using G_K = ck::tensor_layout::convolution::G_K;
using GK_Tuple = ck::Tuple<G_K>;
......
......@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
// clang-format on
>;
// NGCHW requires transpose, we use vector loads and stores params for them
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 1, 2>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
op_ptrs);
}
#endif
}
}
......@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
op_ptrs);
}
#endif
}
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NGKDHW>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
op_ptrs);
}
#endif
}
}
......
......@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
......@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
return {0, 1, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKW>)
{
return {1, 0, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKHW>)
{
return {1, 0, 2, 3, 4};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGCDHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NGKDHW>)
{
return {1, 0, 2, 3, 4, 5};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCYX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKHW>)
......@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NGCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
......@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NGKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::GNDHWK>)
......
......@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
)
if(DL_KERNELS)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NGCHW,
GKYXC,
NGKHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
2,
NGCHW,
GKYXC,
NGKHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -8,6 +8,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp
)
if(DL_KERNELS)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment