"...resnet50_tensorflow.git" did not exist on "00fa8b123a1c2bb7a5fd2fd126793acdbfc31245"
Commit 82a15a27 authored by Jing Zhang

add xdlops emulation on v100

parent e69b1970
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
class LeftPads,
class RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
class GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
class GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmN,
index_t GemmBBlockCopyDstDataPerWrite_GemmN>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmM = K;
constexpr index_t GemmK = C * Y * X;
constexpr index_t GemmN = N * Ho * Wo;
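// Worked example with hypothetical sizes: N = 128, C = 256, Hi = Wi = 14,
// K = 256, Y = X = 3, pad 1, stride 1 give Ho = Wo = 14, hence
// GemmM = 256, GemmK = 256 * 3 * 3 = 2304, GemmN = 128 * 14 * 14 = 25088.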
static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0,
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
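// The two Embed transforms encode the convolution window arithmetic:
// hip = ConvDilationH * y + ConvStrideH * ho (the trailing 0 coefficient is a
// constant offset), and likewise for wip, so each (y, ho) / (x, wo) pair
// addresses exactly one padded input pixel.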
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1<
GridSize,
BlockSize,
Float,
AccDataType,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
Sequence<0, 1>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
InMemoryDataOperation::Set>{};
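// note: the weight tensor is GemmA (GemmK x GemmM, i.e. transposed) and the
// input tensor is GemmB, matching the descriptor order above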
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class Float,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
//static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
// a type alias keeps the repeated instantiations below readable
using XdlopsGemmType = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>;
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
static constexpr index_t WaveSize = 64;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemmType{}.GetOutputLayout(); }
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
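// Worked example (assumed config): BlockSize = 256 and WaveSize = 64 give 4
// waves; with GemmMWaves = GemmNWaves = 2 the waves tile C as a 2x2 grid:
// waveId 0 -> (m0, n0), 1 -> (m0, n1), 2 -> (m1, n0), 3 -> (m1, n1), and wave
// (m, n) reads A starting at m * GemmMPerWave and B at n * GemmNPerWave.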
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
XdlopsGemmType{}.template Run<M, N, K>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = XdlopsGemmType{}.GetBeginOfThreadBlk(i);
const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
const index_t row = waveId / GemmNWaves * GemmMPerWave + thread_mtx_on_blk.row;
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
const index_t reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
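// e.g. GemmMPerWave = GemmNPerWave = 64 gives 64 * 64 / 64 = 64 accumulator
// floats per lane, i.e. two float32_t MFMA result vectors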
return make_ConstantMatrixDescriptor_packed(Number<reg_size>{}, Number<1>{});
}
__device__ void XdlopsMatrixCSetZero() const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemmType{}.SetZeroXdlopsRegs(Number<thread_mtx_size>{});
}
template <class FloatC>
__device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
{
constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
XdlopsGemmType{}.ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
}
};
} // namespace ck
#endif
@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
#if 0
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
#endif
const auto thread_cluster_id =
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
@@ -92,6 +99,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
{
mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
}
}
}
template <typename ThreadBufferData, typename BlockDstData>
@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
{
// TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation)
{
@@ -110,6 +123,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
{
mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
}
}
}
template <typename BlockSrcData, typename BlockDstData>
......
This diff is collapsed.
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
namespace ck {
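// Stand-ins for the MFMA ("xdlops") intrinsics so that xdlops kernels can be
// built and profiled on hardware without MFMA support (e.g. V100). Each
// function keeps the accumulator footprint of the real instruction; the fp32
// 32x32 variants below model only per-lane multiply-accumulate traffic, and
// the remaining variants are left as empty stubs.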
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float*>(reg_c);
for(index_t i = 0; i < 32; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i + 32] = reg_c_[i];
}
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i + 16] = reg_c_[i];
}
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
reg_c_[i + 16] = reg_c_[i];
}
}
__device__ void gcnasm_mfma_f32_32x32x2f32(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
auto reg_c_ = reinterpret_cast<float*>(reg_c);
for(index_t i = 0; i < 16; i++)
{
reg_c_[i] += reg_a * reg_b;
}
}
__device__ void gcnasm_mfma_f32_16x16x4f32(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&,
const half4_t&,
float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
__device__ void gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}
template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
template <>
__device__ void gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
// clang-format on
} // namespace ck
#endif
@@ -27,6 +27,9 @@
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#else
#include "amd_xdlops_emulate.hpp"
#endif
#endif
@@ -13,6 +13,19 @@ namespace ck {
using float2_t = float2;
using float4_t = float4;
// float
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// float16
typedef float half4_t __attribute__((ext_vector_type(2)));
typedef float half8_t __attribute__((ext_vector_type(4)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// float16
using half2_t = half2;
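// For emulation the fp16 vectors above are size-preserving stand-ins built
// from float lanes: half4_t (4 x 2 bytes) is declared as a 2-lane float
// vector (2 x 4 bytes), and half8_t as a 4-lane one; the bf16 types keep
// their usual ushort storage. A sanity check one could add here:
// static_assert(sizeof(half4_t) == 4 * sizeof(ushort), "half4_t stand-in size mismatch");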
......
@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#elif 1
+#elif 0
// cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 32;
@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 32;
constexpr index_t BPerBlock = 16;
constexpr index_t EPerBlock = 3;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
......
@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
+#elif 0
// cdata = 64, BlockSize = 32, 32x64x3
constexpr index_t BlockSize = 32;
@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 32;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 3;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 64, 64x64x3
......
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc =
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
constexpr auto wei_kcyx_desc =
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
constexpr auto out_nkhw_desc =
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
constexpr index_t N = out_nkhw_desc.GetLength(I0);
constexpr index_t K = out_nkhw_desc.GetLength(I1);
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t ThreadGemmDataPerReadM = 1;
constexpr index_t ThreadGemmDataPerReadN = 1;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
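// Sanity arithmetic for this tile: cluster lengths 4 x 64 = 256 threads =
// BlockSize; slice lengths 4 x 2 = 8 elements per thread; and 256 * 8 = 2048
// = GemmKPerBlock * GemmMPerBlock (16 * 128), so each A/B block tile is
// covered exactly once per copy.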
constexpr index_t GemmM = K;
constexpr index_t GemmN = N * Ho * Wo;
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
math::integer_divide_ceil(GemmN, GemmNPerBlock);
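// e.g. a hypothetical problem with K = 256, N = 128, Ho = Wo = 28 gives
// GemmM = 256 -> 2 tiles along GemmM and GemmN = 128 * 28 * 28 = 100352 ->
// 784 tiles along GemmN, i.e. GridSize = 2 * 784 = 1568 workgroups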
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN>{};
for(index_t i = 0; i < 10; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / time);
}
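// TFlop/s bookkeeping: time is in ms, so flops / 1e9 / time_ms equals
// (flops / time_s) / 1e12, i.e. TFlop/s. The flop count is presumably
// 2 * N * K * C * Y * X * Ho * Wo (one multiply and one add per MAC).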
// warm up
printf("Warn up running %d times...\n", nrepeat);
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
printf("Start running %d times...\n", nrepeat);
cudaDeviceSynchronize();
auto start = std::chrono::steady_clock::now();
for(index_t i = 0; i < nrepeat; ++i)
{
launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
cudaDeviceSynchronize();
auto end = std::chrono::steady_clock::now();
float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;
printf("Average elapsed time : %f ms, %f TFlop/s\n",
ave_time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time);
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
@@ -20,321 +20,18 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
using namespace ck;
#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 1536;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 73x73
constexpr index_t N = 128;
constexpr index_t C = 160;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35
constexpr index_t N = 128;
constexpr index_t C = 96;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 71;
constexpr index_t WI = 71;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 7x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 320;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 0
// 1x7, 17x17
constexpr index_t N = 128;
constexpr index_t C = 224;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 224;
constexpr index_t Y = 1;
constexpr index_t X = 7;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 1
// 3x3, 299x299 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 299;
constexpr index_t WI = 299;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 32;
constexpr index_t HI = 149;
constexpr index_t WI = 149;
constexpr index_t K = 32;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 17x17, stride 2
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 192;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 35x35
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 96;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 35x35, stride 2
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 384;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x3, 8x8
constexpr index_t N = 128;
constexpr index_t C = 384;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 448;
constexpr index_t Y = 1;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 1>;
using RightPads = Sequence<0, 1>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 0
// 3x1, 8x8
constexpr index_t N = 128;
constexpr index_t C = 448;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 0>;
using RightPads = Sequence<1, 0>;
#elif 1
// 3x3, 147x147
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 147;
constexpr index_t WI = 147;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>;
#elif 1
// 3x3, 73x73
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 73;
constexpr index_t WI = 73;
constexpr index_t K = 96;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14
-constexpr index_t N = 128;
+constexpr index_t N = 64;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
-constexpr index_t K = 256;
+constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -343,172 +40,6 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 14x14, stride 2
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 512;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 28x28
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56, stride 2
constexpr index_t N = 128;
constexpr index_t C = 256;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 7x7, 230x230 stride=2
constexpr index_t N = 128;
constexpr index_t C = 3;
constexpr index_t HI = 230;
constexpr index_t WI = 230;
constexpr index_t K = 64;
constexpr index_t Y = 7;
constexpr index_t X = 7;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride = 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 28x28, stride 2
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 1x1, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 7x7
constexpr index_t N = 128;
constexpr index_t C = 512;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t K = 512;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
// 1x1, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 56x56
constexpr index_t N = 128;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 64;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#endif
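// note: only the first branch of the #if/#elif chain above that evaluates to
// nonzero is compiled; any later "#elif 1" branches are dead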
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
@@ -603,7 +134,7 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
-device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
+device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
......