Commit 5e95b637 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Comments, adjustment

parent 2a768425
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -16,6 +16,7 @@ using OutElementOp = PassThrough; ...@@ -16,6 +16,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial, NDimSpatial,
ck::tensor_layout::convolution::GNDHWC, ck::tensor_layout::convolution::GNDHWC,
...@@ -52,11 +53,11 @@ using DeviceConvBwdWeightInstance = ...@@ -52,11 +53,11 @@ using DeviceConvBwdWeightInstance =
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
4, 4, // CShuffleMXdlPerWavePerShuffle
2, 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, S<1, 32, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1>; 1>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType, InDataType,
......
...@@ -18,6 +18,7 @@ using OutElementOp = PassThrough; ...@@ -18,6 +18,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
...@@ -54,23 +55,24 @@ using DeviceConvBwdWeightInstance = ...@@ -54,23 +55,24 @@ using DeviceConvBwdWeightInstance =
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTranstest/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cppferDstScalarPerVector_K1 4, // ABlockTranstest/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cppferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl 2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA, // ComputeTypeA ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB ComputeTypeB>; // ComputeTypeB
// clang-format on
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment