Commit bd4c668e authored by Chao Liu's avatar Chao Liu
Browse files

experiment dummy static transform

parent 435f5f91
#ifndef CK_DUMMY_STATIC_TRANSFORM_HPP
#define CK_DUMMY_STATIC_TRANSFORM_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
struct DummyStaticTransform
{
__device__ void Run(Float* const __restrict__ p_in_global,
Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input
const index_t k0 = p_in_global[get_thread_local_1d_id()];
const index_t n0 = p_in_global[get_thread_local_1d_id()];
auto coord = typename TensorCoordinate<decltype(in_gemmk_gemmn_global_desc)>::type(k0, n0);
if(get_block_1d_id() < coord.GetOffset())
{
for(index_t k = 0; k < 1; ++k)
{
for(index_t n = 0; n < 4; ++n)
{
auto tmp = coord + Array<index_t, 2>{k, n};
Float value = 1;
transfer_data<Float,
1,
AddressSpace::Vgpr,
AddressSpace::Global,
InMemoryDataOperation::Set,
1,
1>(&value,
0,
true,
1,
p_in_global,
tmp.GetOffset(),
tmp.IsOffsetValidAssumingUpperIndexIsValid(),
in_gemmk_gemmn_global_desc.GetElementSpace());
}
}
}
}
};
} // namespace ck
#endif
...@@ -196,6 +196,7 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave, ...@@ -196,6 +196,7 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK #if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data, return __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data,
0, 0,
...@@ -209,6 +210,12 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave, ...@@ -209,6 +210,12 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
return __llvm_amdgcn_buffer_load_f32( return __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_addr_base + src_thread_addr_offset, false, false); src_wave_buffer_resource.data, 0, src_addr_base + src_thread_addr_offset, false, false);
#endif #endif
#else
return src_thread_data_valid
? __llvm_amdgcn_buffer_load_f32(
src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false)
: 0;
#endif
} }
template <> template <>
...@@ -570,6 +577,7 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread, ...@@ -570,6 +577,7 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK #if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32(*p_src_thread, __llvm_amdgcn_buffer_store_f32(*p_src_thread,
dst_wave_buffer_resource.data, dst_wave_buffer_resource.data,
...@@ -587,6 +595,13 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread, ...@@ -587,6 +595,13 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
false, false,
false); false);
#endif #endif
#else
if(dst_thread_data_valid)
{
__llvm_amdgcn_buffer_store_f32(
*p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false);
}
#endif
} }
template <> template <>
......
...@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1 #elif 0
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -290,7 +290,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -290,7 +290,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 1
// cdata = 64, BlockSize = 128, 64x128x8 // cdata = 64, BlockSize = 128, 64x128x8
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_static_transform.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_dummy_transform(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t BlockSize = 256;
constexpr index_t GridSize = 1;
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
using dummy_transform = DummyStaticTransform<GridSize,
BlockSize,
float,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads>;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
launch_kernel(run_gridwise_operation<dummy_transform,
float* const __restrict__,
float* const __restrict__,
float* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<float*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
...@@ -52,7 +52,7 @@ int main(int argc, char* argv[]) ...@@ -52,7 +52,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
...@@ -245,7 +245,7 @@ int main(int argc, char* argv[]) ...@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
#endif #endif
} }
#if 0 #if 1
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 1 #elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
...@@ -256,17 +256,17 @@ int main(int argc, char* argv[]) ...@@ -256,17 +256,17 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
out_nkhw, out_nkhw,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
if(do_verification) if(do_verification)
{ {
......
...@@ -14,26 +14,27 @@ ...@@ -14,26 +14,27 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_transform.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 0
// 1x1, 17x17 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 192;
constexpr index_t HI = 17; constexpr index_t HI = 71;
constexpr index_t WI = 17; constexpr index_t WI = 71;
constexpr index_t K = 256; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 3;
constexpr index_t X = 1; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
// 1x1, 8x8 // 1x1, 8x8
constexpr index_t N = 128; constexpr index_t N = 128;
...@@ -109,7 +110,7 @@ int main(int argc, char* argv[]) ...@@ -109,7 +110,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 1 #elif 0
// 1x7, 17x17 // 1x7, 17x17
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -141,7 +142,6 @@ int main(int argc, char* argv[]) ...@@ -141,7 +142,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 3x3, 147x147 // 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 32; constexpr index_t C = 32;
constexpr index_t HI = 147; constexpr index_t HI = 147;
...@@ -157,7 +157,6 @@ int main(int argc, char* argv[]) ...@@ -157,7 +157,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
// 3x3, 149x149 // 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 32; constexpr index_t C = 32;
constexpr index_t HI = 149; constexpr index_t HI = 149;
...@@ -201,7 +200,7 @@ int main(int argc, char* argv[]) ...@@ -201,7 +200,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 35x35, stride 2 // 3x3, 35x35, stride 2
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 288; constexpr index_t C = 288;
...@@ -244,21 +243,6 @@ int main(int argc, char* argv[]) ...@@ -244,21 +243,6 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>; using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>; using RightPads = Sequence<1, 0>;
#elif 0 #elif 0
...@@ -278,7 +262,6 @@ int main(int argc, char* argv[]) ...@@ -278,7 +262,6 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
// 7x1, 73x73 // 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 64; constexpr index_t C = 64;
constexpr index_t HI = 73; constexpr index_t HI = 73;
...@@ -352,10 +335,10 @@ int main(int argc, char* argv[]) ...@@ -352,10 +335,10 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 192;
constexpr index_t HI = 28; constexpr index_t HI = 28;
constexpr index_t WI = 28; constexpr index_t WI = 28;
constexpr index_t K = 128; constexpr index_t K = 128;
...@@ -367,7 +350,7 @@ int main(int argc, char* argv[]) ...@@ -367,7 +350,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
// 3x3, 14x14 // 3x3, 14x14
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -382,7 +365,7 @@ int main(int argc, char* argv[]) ...@@ -382,7 +365,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
// 1x1, 56x56, stride 2 // 1x1, 56x56, stride 2
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -472,7 +455,7 @@ int main(int argc, char* argv[]) ...@@ -472,7 +455,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
// 1x1, 56x56 // 1x1, 56x56
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 64; constexpr index_t C = 64;
...@@ -487,7 +470,7 @@ int main(int argc, char* argv[]) ...@@ -487,7 +470,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 56x56 // 3x3, 56x56
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 64; constexpr index_t C = 64;
...@@ -565,7 +548,7 @@ int main(int argc, char* argv[]) ...@@ -565,7 +548,7 @@ int main(int argc, char* argv[])
#endif #endif
} }
#if 1 #if 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -577,7 +560,7 @@ int main(int argc, char* argv[]) ...@@ -577,7 +560,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -589,6 +572,18 @@ int main(int argc, char* argv[]) ...@@ -589,6 +572,18 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1
device_dummy_transform(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#endif #endif
if(do_verification) if(do_verification)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment