Commit 5771a040 authored by carlushuang

fix a bug in general index calculation

parent 5e6cca6f
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP

#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include "convolution_forward_specialization_cpu.hpp"
#include <immintrin.h>
namespace ck {
namespace cpu {
namespace avx2_util {

inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
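    // After the 16-wide loop the remaining count is < 16, so each of the bits
    // 8/4/2/1 selects at most one tail copy; e.g. n = 23 runs one 16-float
    // iteration, then the 4-, 2- and 1-float tails (23 = 16 + 4 + 2 + 1).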
    ck::index_t i_n    = n;
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
        _mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
        p_dst += 16;
        p_src += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
        p_dst += 8;
        p_src += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
        p_dst += 4;
        p_src += 4;
    }
    if(i_n & 2)
    {
        _mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
        p_dst += 2;
        p_src += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *p_src;
    }
}
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n = n;
    float* p_dst    = reinterpret_cast<float*>(dst);
    __m256 ymm      = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
    __m128 xmm      = _mm_set1_ps(*reinterpret_cast<const float*>(&value));

    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, ymm);
        _mm256_storeu_ps(p_dst + 8, ymm);
        p_dst += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, ymm);
        p_dst += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, xmm);
        p_dst += 4;
    }
    if(i_n & 2)
    {
        // _mm_storeu_si64 takes a __m128i, so reinterpret the float vector first
        _mm_storeu_si64(p_dst, _mm_castps_si128(xmm));
        p_dst += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *reinterpret_cast<const float*>(&value);
    }
}
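#if 0
// Minimal usage sketch of the two helpers above (illustrative only, not part of
// the original source): copy a 10-float tail, then clear it; both helpers use
// the same 16-8-4-2-1 split.
inline void memcpy32_avx2_usage_example()
{
    float a[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    float b[10];
    memcpy32_avx2(b, a, 10); // one 8-float ymm copy + one 2-float tail
    memset32_avx2(b, 0, 10); // same decomposition, broadcasting the value
}
#endif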
inline void
transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
    // TODO: use vinsertf128 for better port usage. vperm2f128 is slow
    __m256 r0, r1, r2, r3, r4, r5, r6, r7;
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
    r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
    r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
    r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
    r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
    r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
    r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
    r7 = _mm256_loadu_ps(p_src + 7 * stride_src);
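    // stage 1: interleave elements of adjacent row pairs within each 128-bit lane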
    t0 = _mm256_unpacklo_ps(r0, r1);
    t1 = _mm256_unpackhi_ps(r0, r1);
    t2 = _mm256_unpacklo_ps(r2, r3);
    t3 = _mm256_unpackhi_ps(r2, r3);
    t4 = _mm256_unpacklo_ps(r4, r5);
    t5 = _mm256_unpackhi_ps(r4, r5);
    t6 = _mm256_unpackhi_ps(r6, r7) == t7 ? t7 : _mm256_unpacklo_ps(r6, r7);
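    // stage 2: shuffle so each 128-bit lane gathers four rows' elements of one column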
    r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
    r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
    r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
    r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
    r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
    r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
    r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
    r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));
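    // stage 3: exchange 128-bit halves to join the low/high row groups of each column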
    t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
    t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
    t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
    t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
    t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
    t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
    t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
    t7 = _mm256_permute2f128_ps(r3, r7, 0x31);

    _mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
    _mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
    _mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
    _mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
    _mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
    _mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
    _mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
    _mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}
} // namespace avx2_util

using ConvolutionForwardSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;

// assume input -> a matrix
// assume input -> MC * KC
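// The transfer below is an implicit-GEMM (im2col-like) gather: each GEMM row m
// maps to an output pixel (n, ho, wo) and each GEMM column k to a filter tap and
// channel (y, x, c); taps that land in the padding region are written as zeros.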
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;
    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc&,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            N  = 1;
            Hi = 1;
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
            Ho = 1;
            Wo = Wi;
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = 1;
            Dx = 1;
            Sx = 1;
            Py = 0;
            Px = 0;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
            Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
            Dx = 1;
            Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
            Py = 0;
            Px = 0;
        }
        else
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
            Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
            Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
            Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
            Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
            Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
            Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
            Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
            Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
            Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
        }

        // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
        // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
        input_offset_acc_wi        = Sx * C;
        input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
        input_offset_ovf_hi_acc_n  = Hi * Wi * C - Ho * Sy * Wi * C;
        // input_offset_acc_c = 1;
        input_offset_ovf_c_acc_x = Dx * C - C;
        input_offset_ovf_x_acc_y = Dy * Wi * C - Fx * Dx * C;
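        // Derivation of these deltas (sketch): one step of i_wo advances i_wi by
        // Sx, i.e. +Sx * C floats. When i_wo wraps after Wo steps, i_ho advances
        // once (+Sy * Wi * C) and the Wo * Sx * C accumulated along wi is rolled
        // back; likewise a wrap of i_ho advances i_n (+Hi * Wi * C) and rolls back
        // Ho * Sy * Wi * C. The c/x/y deltas follow the same wrap-and-carry scheme.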
        src_offset = -Py * Wi * C - Px * C;
        i_n        = 0;
        i_c        = 0;
        i_hi       = -Py;
        i_wi       = -Px;
        i_ho       = 0;
        i_wo       = 0;
        i_y        = 0;
        i_x        = 0;
        i_gemm_k   = 0;
#if 0
        printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
               "Py:%d, Px:%d\n",
               N,
               Hi,
               Wi,
               C,
               Ho,
               Wo,
               Fy,
               Fx,
               Dy,
               Sy,
               Dx,
               Sx,
               Py,
               Px);
#endif
    }
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            i_wi       = idx_m;
            i_c        = idx_k;
            src_offset = i_wi * C + i_c;
            // printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            i_c  = idx_k;
            i_x  = 0;
            i_y  = 0;
            i_hi = i_ho * Sy;
            i_wi = i_wo * Sx;

            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
        }
        else
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            if(idx_k == 0)
            {
                i_c  = 0;
                i_x  = 0;
                i_y  = 0;
                i_hi = i_ho * Sy - Py;
                i_wi = i_wo * Sx - Px;
            }
            else
            {
                i_c  = idx_k % C;
                i_x  = (idx_k / C) % Fx;
                i_y  = (idx_k / C) / Fx;
                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;
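                // e.g. (illustrative) with C = 64, Fx = 3: idx_k = 130 decomposes
                // to i_c = 130 % 64 = 2, i_x = (130 / 64) % 3 = 2, i_y = (130 / 64) / 3 = 0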
            }
            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
            // printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
            //     __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
        }
    }
    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc,
             const SrcBuffer& src_buf,
             const DstDesc& dst_desc,
             DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            float* p_src    = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            dst_buf.p_data_ = p_src;
        }
        else
        {
            const ck::index_t m_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
            // m_per_block);

            if constexpr(ConvForwardSpecialization ==
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
            {
                ck::index_t i_m_itr = m_per_block;

                // standard 8-4-2-1 pattern
                while(i_m_itr >= 8)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);

                    i_m_itr -= 8;
                    p_dst += 8 * k_per_block;
                    p_src += 8 * C;
                }
                if(i_m_itr & 4)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);

                    p_dst += 4 * k_per_block;
                    p_src += 4 * C;
                }
                if(i_m_itr & 2)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);

                    p_dst += 2 * k_per_block;
                    p_src += 2 * C;
                }
                if(i_m_itr & 1)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                }
            }
            else if constexpr(ConvForwardSpecialization ==
                              ConvolutionForwardSpecialization_t::Filter1x1Pad0)
            {
                ck::index_t i_m_itr  = m_per_block;
                ck::index_t i_wo_itr = i_wo;
                ck::index_t i_ho_itr = i_ho;
                while(i_m_itr > 0)
                {
                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                    p_dst += k_per_block;
                    i_wo_itr++;
                    p_src += input_offset_acc_wi;
                    if(i_wo_itr >= Wo)
                    {
                        i_wo_itr = 0;
                        i_ho_itr++;
                        p_src += input_offset_ovf_wi_acc_hi;
                    }
                    if(i_ho_itr >= Ho)
                    {
                        i_ho_itr = 0;
                        // i_n++;
                        p_src += input_offset_ovf_hi_acc_n;
                    }
                    i_m_itr--;
                }
            }
            else
            {
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                if constexpr(GemmKSpecialization ==
                             ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
                {
                    // c % k_per_block == 0, so every time k_per_block here is the same
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
                    // src_offset:%d, input_offset_acc_wi:%d,
                    // input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
                    //     __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
                    //     src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
                    //     input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);

                    // printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
                    // float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
                    // + ck::index_t(-1),
                    // sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
                    // reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));

                    while(i_m_itr > 0)
                    {
                        // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
                        // i_hi_itr:%d, src_offset:%d -> %p\n",
                        //     __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
                        //     p_src);
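                        // reinterpret as unsigned: a negative i_hi/i_wi (i.e. inside
                        // the pad region) becomes a huge value, so one unsigned
                        // compare tests 0 <= index < bound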
                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
                            avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                        else
                            avx2_util::memset32_avx2(p_dst, 0, k_per_block);
                        p_dst += k_per_block;
                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                    // printf("[%d] \n", __LINE__);
                }
                else
                {
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                    // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                    while(i_m_itr > 0)
                    {
                        /*** go along Gemm K ***/
                        const float* p_src_k = p_src;
                        float* p_dst_k       = p_dst;

                        ck::index_t i_wi_itr_k = i_wi_itr;
                        ck::index_t i_hi_itr_k = i_hi_itr;
                        ck::index_t i_c_itr_k  = i_c;
                        ck::index_t i_y_itr_k  = i_y;
                        ck::index_t i_x_itr_k  = i_x;

                        ck::index_t i_k_itr = k_per_block;
                        while(i_k_itr > 0)
                        {
                            ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
                            if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                               (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                                avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
                            else
                                avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);
                            p_dst_k += current_k_block;
                            p_src_k += current_k_block;
                            i_c_itr_k += current_k_block;
                            if(i_c_itr_k >= C)
                            {
                                i_c_itr_k = 0;
                                i_x_itr_k++;
                                i_wi_itr_k += Dx;
                                p_src_k += input_offset_ovf_c_acc_x;
                            }
                            if(i_x_itr_k >= Fx)
                            {
                                i_x_itr_k = 0;
                                i_y_itr_k++;
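                                // x wrapped: roll back the Fx dilated-x steps
                                // accumulated in i_wi_itr_k and take one dilated-y
                                // step, mirroring input_offset_ovf_x_acc_y
                                // (= Dy * Wi * C - Fx * Dx * C)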
                                i_wi_itr_k -= Dx * Fx;
                                i_hi_itr_k += Dy;
                                p_src_k += input_offset_ovf_x_acc_y;
                            }

                            i_k_itr -= current_k_block;
                        }
                        /*** go along Gemm K ***/

                        p_dst += k_per_block;

                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }

                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }

                        i_m_itr--;
                    }
                }
            }
        }
    }
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
            i_c += move_k;
            src_offset += move_k;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_c += move_k;
            src_offset += move_k;
        }
        else
        {
            if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // c % k_per_block == 0, so every time k_per_block here is the same
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                // printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
                // fflush(stdout);

                // TODO: branch seems weird
                i_c += move_k;
                src_offset += move_k;
                // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                if(i_c >= C)
                {
                    i_c = 0;
                    i_x++;
                    i_wi += Dx;
                    src_offset += Dx * C - C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                if(i_x >= Fx)
                {
                    i_x = 0;
                    i_y++;
                    i_wi = i_wi - Fx * Dx;
                    i_hi += Dy;

                    src_offset += Dy * Wi * C - Fx * Dx * C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }

                // printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
                // i_hi, i_wi, src_offset); fflush(stdout);
            }
            else
            {
                i_gemm_k += move_k;

                i_c = i_gemm_k % C;
                i_x = (i_gemm_k / C) % Fx;
                i_y = (i_gemm_k / C) / Fx;

                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;

                src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            }
        }
    }

    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_n;
    ck::index_t i_c;
    ck::index_t i_hi;
    ck::index_t i_wi;
    ck::index_t i_ho;
    ck::index_t i_wo;
    ck::index_t i_y;
    ck::index_t i_x;
    ck::index_t i_gemm_k;

    ck::index_t N;
    // ck::index_t K;
    ck::index_t C;
    ck::index_t Hi;
    ck::index_t Wi;
    ck::index_t Ho;
    ck::index_t Wo;

    ck::index_t Sy;
    ck::index_t Sx;

    ck::index_t Dy;
    ck::index_t Dx;

    ck::index_t Py;
    ck::index_t Px;

    ck::index_t Fy;
    ck::index_t Fx;

    intptr_t input_offset_acc_wi;
    intptr_t input_offset_ovf_wi_acc_hi;
    intptr_t input_offset_ovf_hi_acc_n;

    // intptr_t input_offset_acc_c;
    intptr_t input_offset_ovf_c_acc_x;
    intptr_t input_offset_ovf_x_acc_y;

    intptr_t src_offset; // keep this as pointer type in case we have negative offset
};
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;

    // using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    // using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
        GemmN  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        GemmK  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }

    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k  = src_slice_origin_idx[Number<1>{}];
        ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];

        i_gemm_n = idx_n0 * GemmN1 + idx_n1;
        // i_gemm_k = idx_k;

        src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1; // Note we transpose here

        // printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
        // src_offset);
    }

    void SetDstSliceOrigin(const DstDesc&, const Index&) {}

    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            // TODO: weight NHWC does not support bypass
        }
        else
        {
            const ck::index_t n_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
            //     dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}],
            //     dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}],
            //     k_per_block);

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
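            // e.g. (illustrative) with k_per_block = 2 and 16 source rows: rows
            // n0..n7 and n8..n15 each become one block stored k-major, holding the
            // 8 consecutive n values for k = 0, then the 8 values for k = 1 (n1 = 8)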
            for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
            {
                ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
                ck::index_t i_k_itr     = k_per_block;
                if(current_n_8 == 8)
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    while(i_k_itr >= 8)
                    {
                        avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
                        p_dst_k += 8 * 8;
                        p_src_k += 8;
                        i_k_itr -= 8;
                    }
                    if(i_k_itr & 4)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
                        p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
                        p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
                        p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
                        p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
                        p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
                        p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
                        p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];

                        p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
                        p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
                        p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
                        p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
                        p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
                        p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
                        p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
                        p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];

                        p_dst_k += 4 * 8;
                        p_src_k += 4;
                    }
                    if(i_k_itr & 2)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k += 2 * 8;
                        p_src_k += 2;
                    }
                    if(i_k_itr & 1)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                    }
                }
                else
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;

                    for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
                    {
                        for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
                        {
                            ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;

                            float v =
                                i_current_n_itr < GemmN ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;

                            p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                        }
                    }
                }

                p_dst += 8 * k_per_block;
                p_src += 8 * GemmK;
            }
        }
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k  = src_slice_origin_step_idx[Number<1>{}];
        ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];

        // i_gemm_k += move_k;

        // printf("wei move:%d\n", move_k); fflush(stdout);

        src_offset += move_k + move_n0 * GemmK * GemmN1;
    }

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_gemm_n;
    // ck::index_t i_gemm_k;

    // ck::index_t GemmN0;
    ck::index_t GemmN1;
    ck::index_t GemmN;
    ck::index_t GemmK;

    intptr_t src_offset;
};
template <typename SrcData,
typename DstData, template <typename SrcData,
typename SrcDesc, typename DstData,
typename DstDesc, typename SrcDesc,
typename ElementwiseOperation, typename DstDesc,
bool BypassTransfer, typename ElementwiseOperation,
ConvolutionForwardSpecialization_t ConvForwardSpecialization, bool BypassTransfer,
ConvolutionForwardGemmKSpecialization_t GemmKSpecialization> ConvolutionForwardSpecialization_t ConvForwardSpecialization,
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
{ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension(); {
using Index = MultiIndex<nDim>; static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
using Index = MultiIndex<nDim>;
constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
const SrcDesc& src_desc, constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
const Index&, const SrcDesc& src_desc,
const DstDesc& dst_desc, const Index&,
const Index&, const DstDesc& dst_desc,
const ElementwiseOperation& element_op) const Index&,
: element_op_(element_op) const ElementwiseOperation& element_op)
{ : element_op_(element_op)
DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; {
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
src_offset = 0;
dst_offset = 0; src_offset = 0;
} dst_offset = 0;
}
void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
{ void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
if constexpr(BypassTransfer) {
{ if constexpr(BypassTransfer)
auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}]; {
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}]; auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];
src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
} src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
} }
}
void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
{ void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}]; {
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}]; i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];
dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
} dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
}
template <typename SrcBuffer, typename DstBuffer>
void template <typename SrcBuffer, typename DstBuffer>
Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf) void
{ Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
if constexpr(BypassTransfer) {
{ if constexpr(BypassTransfer)
src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset; {
} src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
else }
{ else
const ck::index_t m_per_block = {
src_desc.GetTransforms()[Number<0>{}] const ck::index_t m_per_block =
.GetUpperLengths()[Number<0>{}]; // must be multiple of 8 src_desc.GetTransforms()[Number<0>{}]
const ck::index_t n_per_block = .GetUpperLengths()[Number<0>{}]; // must be multiple of 8
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);

            const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;

            ck::index_t i_m_itr = m_per_block;

            // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
            // DstGemmN, n_per_block);fflush(stdout);

            // standard 8-4-2-1 pattern: copy eight rows per iteration, then use
            // the low bits of the remaining row count to mop up 4, 2 and 1 rows
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);

                i_m_itr -= 8;
                p_dst += 8 * DstGemmN;
                p_src += 8 * n_per_block;
            }

            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);

                p_dst += 4 * DstGemmN;
                p_src += 4 * n_per_block;
            }

            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);

                p_dst += 2 * DstGemmN;
                p_src += 2 * n_per_block;
            }

            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
            }

            // printf("xxxx %d\n",__LINE__);fflush(stdout);
        }
    }
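
    // Illustrative note (a sketch, not part of the build): the unrolled copy
    // above is behaviorally equivalent to the plain row loop
    //
    //     for(ck::index_t i = 0; i < m_per_block; ++i)
    //         avx2_util::memcpy32_avx2(
    //             p_dst + i * DstGemmN, p_src + i * n_per_block, current_n);
    //
    // Unrolling by eight exposes independent row copies so loads and stores of
    // neighboring rows can overlap, and the bit tests (& 4, & 2, & 1) handle
    // the remainder without a scalar-count loop.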
    // src_slice_origin_step_idx needs to be known at compile time, for performance reasons
    void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}

    // dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_dst_gemm_m;
    ck::index_t i_dst_gemm_n;

    ck::index_t DstGemmM;
    ck::index_t DstGemmN;

    intptr_t src_offset;
    intptr_t dst_offset;
};

} // namespace cpu
} // namespace ck

#endif
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

using InLayout  = ck::tensor_layout::gemm::RowMajor;    // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PT = ck::tensor_operation::cpu::element_wise::PassThrough;

using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
    ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                  WeiType,
                                                  OutType,
                                                  InLayout,
                                                  WeiLayout,
                                                  NonTemporalStore>;
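
// Note (an assumption from the 4x24 name, not stated in this file): the
// dispatch selects an AVX2 micro-kernel computing a 4-row x 24-column f32
// tile, i.e. three ymm registers per row and twelve ymm accumulators in
// total, leaving registers free for the A broadcast and B loads.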

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;

static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
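
// Reading of the aliases above (hedged, inferred from the names): the
// Filter1x1* specializations skip the Y/X filter-window walk for 1x1
// convolutions, the GemmK specializations choose whether the GEMM K loop runs
// directly over C (NHWC_GemmKLoopOverC) or over the generic K = Y*X*C
// decomposition, and LoopOver_MNK / LoopOver_MKN pick the blocked loop-nest
// order.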

// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, GemmKLoopOverC , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float , float , float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
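
// Each DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(...) use expands to eight
// instance types: {ConvFwdDefault, ConvFwd1x1P0} x {GemmKLoopOverC,
// DefaultGemmKLoop} x {LoopOver_MNK, LoopOver_MKN}, all sharing the same
// blocking (m/n/k_per_block) and per-thread tile (m/n_per_thread) parameters.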

using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120,  64, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)>;
// clang-format on

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
}

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
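
// Hypothetical caller sketch (illustrative only, not part of this commit; the
// test driver below is assumed to do something equivalent):
//
//   std::vector<DeviceConvFwdPtr<PT, PT, PT>> conv_ptrs;
//   ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
//       add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(conv_ptrs);
//   for(auto& conv_ptr : conv_ptrs)
//   {
//       // build the convolution argument, then run and verify each instance
//   }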
@@ -37,26 +37,53 @@ using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
 using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
 
 template <typename T>
-static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
+static bool
+check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
 {
     int error_count = 0;
-    float max_diff  = 1e-6;
+    float max_diff  = 1e-5;
+    double square_difference = .0;
+    double mag1 = .0;
+    double mag2 = .0;
 
     for(int i = 0; i < ref.mData.size(); ++i)
     {
-        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
-        if(max_diff < diff)
+        double ri = (double)ref.mData[i];
+        double pi = (double)result.mData[i];
+        double d  = ri - pi;
+        if(per_pixel_check)
         {
-            error_count++;
-            printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
-                   i,
-                   double(ref.mData[i]),
-                   double(result.mData[i]),
-                   diff);
+            if(max_diff < std::abs(d))
+            {
+                error_count++;
+                printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
+                       i,
+                       double(ref.mData[i]),
+                       double(result.mData[i]),
+                       d);
+            }
         }
+        square_difference += d * d;
+        if(std::abs(mag1) < std::abs(ri))
+            mag1 = ri;
+        if(std::abs(mag2) < std::abs(pi))
+            mag2 = pi;
     }
-    return error_count == 0;
+    double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
+    double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);
+    if(computed_nrms >= nrms)
+        printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %lf\n",
+               computed_nrms,
+               mag1,
+               mag2,
+               nrms);
+    return computed_nrms < nrms && error_count == 0;
 }
 
 float calculate_gflops() {}
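
For reference, the normalized-RMS bound used by the new check_out reduces to
nrms(ref, res) = ||ref - res||_2 / (sqrt(N) * max(|ref|_max, |res|_max)).
A minimal standalone sketch of the same metric (the helper name nrms_of is
ours, not the commit's):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

static double nrms_of(const std::vector<double>& ref, const std::vector<double>& res)
{
    double sq  = 0.0;
    double mag = std::numeric_limits<double>::min(); // guards the division when both tensors are all zero
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double d = ref[i] - res[i];
        sq += d * d;
        mag = std::max({mag, std::fabs(ref[i]), std::fabs(res[i])});
    }
    return std::sqrt(sq) / (std::sqrt(static_cast<double>(ref.size())) * mag);
}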

@@ -171,20 +198,28 @@ int main(int argc, char* argv[])
               << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
               << ", Threads:" << omp_get_max_threads() << std::endl;
 
+    int per_pixel_check = 0;
+
     switch(init_method)
     {
-    case 0: break;
+    case 0:
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
+        per_pixel_check = 1;
+        break;
     case 1:
         in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
         // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
         wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
         // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
+        per_pixel_check = 1;
         break;
     case 2:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
+        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
+        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
         break;
     case 3:
 #define PACK_32(v24, v16, v8, v0) \
@@ -310,7 +345,10 @@ int main(int argc, char* argv[])
 
         out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
 
-        if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result))
+        if(!check_out(out_n_k_ho_wo_host_result,
+                      out_n_k_ho_wo_device_result,
+                      1e-6,
+                      per_pixel_check))
         {
             std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
             success = false;
...