Commit 5e95b637 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Comments, adjustment

parent 2a768425
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......@@ -16,6 +16,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
NDimSpatial,
ck::tensor_layout::convolution::GNDHWC,
......@@ -52,11 +53,11 @@ using DeviceConvBwdWeightInstance =
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
S<1, 32, 1, 8>,
1>;
4, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InDataType,
......
......@@ -18,6 +18,7 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
// clang-format off
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
......@@ -71,6 +72,7 @@ using DeviceConvBwdWeightInstance =
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA, // ComputeTypeA
ComputeTypeB>; // ComputeTypeB
// clang-format on
template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment