Commit 5e95b637 authored by Mateusz Ozga
Browse files

Comments, adjustment

parent 2a768425
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
...@@ -16,6 +16,7 @@ using OutElementOp = PassThrough; ...@@ -16,6 +16,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial, NDimSpatial,
ck::tensor_layout::convolution::GNDHWC, ck::tensor_layout::convolution::GNDHWC,
...@@ -52,11 +53,11 @@ using DeviceConvBwdWeightInstance = ...@@ -52,11 +53,11 @@ using DeviceConvBwdWeightInstance =
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
4, 4, // CShuffleMXdlPerWavePerShuffle
2, 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, S<1, 32, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1>; 1>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType, InDataType,
......
...@@ -18,6 +18,7 @@ using OutElementOp = PassThrough; ...@@ -18,6 +18,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
...@@ -71,6 +72,7 @@ using DeviceConvBwdWeightInstance = ...@@ -71,6 +72,7 @@ using DeviceConvBwdWeightInstance =
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA, // ComputeTypeA ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB ComputeTypeB>; // ComputeTypeB
// clang-format on
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment