Commit 3474c777 authored by Chao Liu's avatar Chao Liu
Browse files

add gemm padding to convnd

parent 7cc806d8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#include "convnd_fwd_common.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_fwd_nwc_kxc_nwk_xdl.hpp"
......@@ -20,10 +18,10 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
#if 0
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
#if 0
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwcKxcNwk_Xdl<
NDimSpatial, //
......@@ -63,6 +61,11 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
#else
using CShuffleDataType = ck::half_t;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial>
using DeviceConvNDFwdInstance =
ck::tensor_operation::device::DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle<
......@@ -76,7 +79,8 @@ using DeviceConvNDFwdInstance =
InElementOp, // Input Elementwise Operation
WeiElementOp, // Weights Elementwise Operation
OutElementOp, // Output Elementwise Operation
ConvFwdDefault, // ConvForwardSpecialization
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment