Commit ebb5522c authored by Mateusz Ozga, committed by root

Apply cshuffle to bwd_weight_cshuffle operator

parent fdfe2102
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -39,33 +39,33 @@ using DeviceConvBwdWeightInstance = ...@@ -39,33 +39,33 @@ using DeviceConvBwdWeightInstance =
WeiElementOp, // WeiElementwiseOperation WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize 64, // BlockSize
128, // MPerBlock 16, // MPerBlock
128, // NPerBlock 16, // NPerBlock
4, // K0PerBlock 32, // K0PerBlock
8, // K1 8, // K1
32, // MPerXdl 16, // MPerXdl
32, // NPerXdl 16, // NPerXdl
2, // MXdlPerWave 1, // MXdlPerWave
2, // NXdlPerWave 1, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl 1>; // CBlockTransferScalarPerVector_NWaveNPerXdl
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -41,26 +41,26 @@ using DeviceConvBwdWeightInstance = ...@@ -41,26 +41,26 @@ using DeviceConvBwdWeightInstance =
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
4, // K0PerBlock 32, // K0PerBlock
8, // K1 8, // K1
32, // MPerXdl 32, // MPerXdl
32, // NPerXdl 32, // NPerXdl
2, // MXdlPerWave 2, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 2, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 2, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
...@@ -40,35 +40,38 @@ using DeviceConvBwdWeightInstance = ...@@ -40,35 +40,38 @@ using DeviceConvBwdWeightInstance =
WeiElementOp, // WeiElementwiseOperation WeiElementOp, // WeiElementwiseOperation
OutElementOp, // OutElementwiseOperation OutElementOp, // OutElementwiseOperation
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
256, // BlockSize 64, // BlockSize
128, // MPerBlock 16, // MPerBlock
128, // NPerBlock 16, // NPerBlock
4, // K0PerBlock 32, // K0PerBlock
8, // K1 8, // K1
32, // MPerXdl 16, // MPerXdl
32, // NPerXdl 16, // NPerXdl
2, // MXdlPerWave 1, // MXdlPerWave
2, // NXdlPerWave 1, // NXdlPerWave
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
true, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
true, // BBlockLdsAddExtraN false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl 2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ComputeTypeA, // ComputeTypeA ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ComputeTypeB>; // ComputeTypeB ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
1, // NumGroupsToMerge
ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
......
...@@ -34,6 +34,94 @@ struct TransformConvBwdWeightToGemmV2 ...@@ -34,6 +34,94 @@ struct TransformConvBwdWeightToGemmV2
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
constexpr static auto
make_out_grid_desc(const index_t N,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& output_strides)
{
const index_t BatchStride = output_strides[0];
const index_t WoStride = output_strides[3];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K),
make_tuple(WoStride, BatchStride, KStride));
}
template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
constexpr static auto
make_in_grid_desc(const index_t N,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& input_strides)
{
const index_t BatchStride = input_strides[0];
const index_t NStride = input_strides[1];
const index_t WiStride = input_strides[3];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Wi, NumGroupsToMerge, C),
make_tuple(WiStride, BatchStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Wi, NumGroupsToMerge, C),
make_tuple(NStride, WiStride, BatchStride, CStride));
}
}
template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const index_t K,
const index_t X,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
const auto XStride = weights_strides[3];
const auto BatchStride = weights_strides[0];
// Add NumGroupsToMerge for the Batch+M dimension and 1 as a placeholder
// for the Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K, X, 1, C),
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
// Pad the placeholder dimension from 1 up to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K),
make_pass_through_transform(X),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// We only need the matrices on the diagonal. Xor returns 0 for equal
// values, so any matrix that is not on the diagonal is stored in the padding.
// To avoid a modulo after the xor, we require NumGroupsToMerge to be a power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K),
make_pass_through_transform(X),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
make_merge_transform(make_tuple(X, NumGroupsToMerge, C))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
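For context on the xor transform above, here is a minimal standalone sketch (not part of this header; every name in it is local to the example) of the property it relies on: gM ^ gN is zero exactly when gM == gN, so only matching group pairs address the real weight data at placeholder index 0, every off-diagonal pair lands in the padded slots, and because NumGroupsToMerge is a power of two the xor result never leaves the padded extent, so no modulo is needed.

#include <cassert>
#include <cstdio>

int main()
{
    // Assumed power of two, mirroring the static_assert in make_wei_grid_desc.
    constexpr int NumGroupsToMerge = 4;

    for(int gM = 0; gM < NumGroupsToMerge; ++gM)
    {
        for(int gN = 0; gN < NumGroupsToMerge; ++gN)
        {
            const int slot = gM ^ gN; // coordinate along the padded placeholder dim
            // Only the diagonal lands on slot 0 (the real weight data);
            // every off-diagonal pair falls into the padding.
            assert((slot == 0) == (gM == gN));
            // The xor of two values below a power of two stays below it,
            // so no modulo is needed to remain inside the padded extent.
            assert(slot < NumGroupsToMerge);
        }
    }
    std::printf("xor diagonal property verified for NumGroupsToMerge = %d\n",
                NumGroupsToMerge);
    return 0;
}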
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false> template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
constexpr static auto constexpr static auto
make_out_grid_desc(const index_t N, make_out_grid_desc(const index_t N,
...@@ -221,6 +309,187 @@ struct TransformConvBwdWeightToGemmV2 ...@@ -221,6 +309,187 @@ struct TransformConvBwdWeightToGemmV2
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const index_t N,
const index_t K,
const index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_strides,
const std::array<index_t, NDimSpatial + 3>& weights_strides,
const std::array<index_t, NDimSpatial + 3>& output_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
{
using namespace ck;
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const index_t GemmKTotal = N * Wo;
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
}
else
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(X, NumGroupsToMerge, C)),
make_merge_transform(make_tuple(N, Wo))),
make_tuple(Sequence<1, 3, 4>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Pad GemmM and GemmN
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
} // function end
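As a rough numeric illustration of the split-K padding computed above, a standalone sketch follows; the sizes are arbitrary example values rather than a real configuration. GemmKTotal = N * Wo is rounded up so that GemmKBatch * GemmK0 * GemmK1Number covers it, and the right-pad transform appends the difference.

#include <cstdio>

int main()
{
    // Arbitrary example values, for illustration only.
    const int GemmK1     = 8;    // GemmK1Number
    const int K0PerBlock = 32;
    const int GemmKBatch = 2;    // batch_k (split-K factor)
    const int GemmKTotal = 1000; // N * Wo for some example problem

    // Same rounding as in MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N above.
    const int denom    = GemmK1 * K0PerBlock * GemmKBatch;
    const int GemmK0   = (GemmKTotal + denom - 1) / denom * K0PerBlock;
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    // Here: ceil(1000 / 512) = 2, so GemmK0 = 64, GemmKPad = 1024,
    // and 24 elements of right padding are appended along GemmKTotal.
    std::printf("GemmK0 = %d, GemmKPad = %d, right pad = %d\n",
                GemmK0, GemmKPad, GemmKPad - GemmKTotal);
    return 0;
}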
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false> template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const index_t N, const index_t N,
......
...@@ -276,7 +276,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -276,7 +276,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> && is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>) is_same_v<ComputeTypeB, float>)
{ {
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev5_instances(
op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
...@@ -284,7 +291,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -284,7 +291,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>) is_same_v<ComputeTypeB, half_t>)
{ {
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instances(
op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -293,7 +307,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -293,7 +307,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -309,7 +329,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -309,7 +329,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> && is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>) is_same_v<ComputeTypeB, float>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -318,7 +344,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -318,7 +344,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>) is_same_v<ComputeTypeB, half_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -328,7 +360,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -328,7 +360,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -341,7 +379,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -341,7 +379,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> && is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>) is_same_v<ComputeTypeB, float>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -350,7 +394,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -350,7 +394,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>) is_same_v<ComputeTypeB, half_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
op_ptrs); op_ptrs);
...@@ -366,7 +416,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -366,7 +416,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
if constexpr(is_same_v<InDataType, ck::bhalf_t> && if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
...@@ -375,7 +431,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -375,7 +431,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
op_ptrs); op_ptrs);
...@@ -429,7 +491,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -429,7 +491,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> && is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>) is_same_v<ComputeTypeB, float>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -438,7 +506,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -438,7 +506,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>) is_same_v<ComputeTypeB, half_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -448,7 +522,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -448,7 +522,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -461,7 +541,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -461,7 +541,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> && is_same_v<OutDataType, float> && is_same_v<ComputeTypeA, float> &&
is_same_v<ComputeTypeB, float>) is_same_v<ComputeTypeB, float>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
...@@ -470,7 +556,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -470,7 +556,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
is_same_v<ComputeTypeB, half_t>) is_same_v<ComputeTypeB, half_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
op_ptrs); op_ptrs);
...@@ -486,7 +578,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -486,7 +578,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
if constexpr(is_same_v<InDataType, ck::bhalf_t> && if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
...@@ -495,7 +593,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -495,7 +593,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ComputeTypeA, ck::bhalf_t> && is_same_v<ComputeTypeA, ck::bhalf_t> &&
is_same_v<ComputeTypeB, ck::bhalf_t>) is_same_v<ComputeTypeB, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
op_ptrs); op_ptrs);
...@@ -510,7 +614,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -510,7 +614,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> && is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, bf8_t> &&
is_same_v<ComputeTypeB, f8_t>) is_same_v<ComputeTypeB, f8_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_default_pipev5_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_pad0_pipev2_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_pad0_pipev5_instances(
op_ptrs); op_ptrs);
} }
#endif #endif
......
# ONLY XDL_AND_DL_KERNELS # ONLY XDL_AND_DL_KERNELS
set(GROUPED_CONV1D_BWD_WEIGHT set(GROUPED_CONV1D_BWD_WEIGHT
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp) xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instance.cpp
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev5_instance.cpp)
if(DL_KERNELS) if(DL_KERNELS)
list(APPEND GROUPED_CONV1D_BWD_WEIGHT list(APPEND GROUPED_CONV1D_BWD_WEIGHT
...@@ -15,3 +24,4 @@ if(DL_KERNELS) ...@@ -15,3 +24,4 @@ if(DL_KERNELS)
endif() endif()
add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT}) add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT})
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...@@ -9,7 +9,7 @@ namespace tensor_operation { ...@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC, GNWC,
GKXC, GKXC,
...@@ -21,7 +21,6 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta ...@@ -21,7 +21,6 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
// 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
...@@ -29,16 +28,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta ...@@ -29,16 +28,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
GNWC, GNWC,
GKXC, GKXC,
GNWK, GNWK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault,
// 2. Filter1x1Stride1Pad0 BlockGemmPipelineScheduler::Intrawave,
add_device_operation_instances( BlockGemmPipelineVersion::v2>{});
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightFilter1x1Stride1Pad0,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_pad0_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightFilter1x1Stride1Pad0,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...@@ -9,7 +9,7 @@ namespace tensor_operation { ...@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC, GNWC,
GKXC, GKXC,
...@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( ...@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
1, 1,
GNWC, GNWC,
GKXC, GKXC,
GNWK, GNWK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
} }
} // namespace instance } // namespace instance
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_pad0_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightFilter1x1Stride1Pad0,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...@@ -9,7 +9,7 @@ namespace tensor_operation { ...@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_pad0_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC, GNWC,
GKXC, GKXC,
...@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( ...@@ -21,22 +21,15 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1,
GNWC,
GKXC,
GNWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
1, 1,
GNWC, GNWC,
GKXC, GKXC,
GNWK, GNWK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
} }
} // namespace instance } // namespace instance
......