gaoqiong/composable_kernel_ROCM

Commit 5771a040, authored Apr 27, 2022 by carlushuang
fix a bug in general index calculation
parent 5e6cca6f

Showing 3 changed files, with 1217 additions and 1179 deletions:

  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp (+1085 -1084)
  library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp (+79 -80)
  test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp (+53 -15)
File: include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp (at 5771a040)
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_AVX2_SPECIALIZED_HPP

#include "common_header.hpp"
#include "data_type_cpu.hpp"
#include "../../gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
#include "dynamic_buffer_cpu.hpp"
#include <immintrin.h>
#include "convolution_forward_specialization_cpu.hpp"

namespace ck {
namespace cpu {
namespace avx2_util {
inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n    = n;
    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);
    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
        _mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
        p_dst += 16;
        p_src += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
        p_dst += 8;
        p_src += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
        p_dst += 4;
        p_src += 4;
    }
    if(i_n & 2)
    {
        _mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
        p_dst += 2;
        p_src += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *p_src;
    }
}
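Since the main loop leaves i_n below 16, the bit tests 8/4/2/1 each fire at most once and together cover every possible remainder exactly. A standalone sketch of the same tail pattern (hypothetical driver, not part of this header; assumes an AVX2-capable x86-64 compiler, e.g. g++ -O2 -mavx2):

// demo_memcpy32.cpp, minimal sketch of the 16-8-4-2-1 tail handling
#include <immintrin.h>
#include <cstdio>

static void copy_floats(float* dst, const float* src, int n)
{
    while(n >= 16)
    {
        _mm256_storeu_ps(dst + 0, _mm256_loadu_ps(src + 0));
        _mm256_storeu_ps(dst + 8, _mm256_loadu_ps(src + 8));
        dst += 16; src += 16; n -= 16;
    }
    // n < 16 here, so each set bit of n selects exactly one tail copy
    if(n & 8) { _mm256_storeu_ps(dst, _mm256_loadu_ps(src)); dst += 8; src += 8; }
    if(n & 4) { _mm_storeu_ps(dst, _mm_loadu_ps(src)); dst += 4; src += 4; }
    if(n & 2) { dst[0] = src[0]; dst[1] = src[1]; dst += 2; src += 2; }
    if(n & 1) { *dst = *src; }
}

int main()
{
    float a[23], b[23] = {};
    for(int i = 0; i < 23; ++i) a[i] = float(i);
    copy_floats(b, a, 23); // 23 = 16 + 4 + 2 + 1, exercises every tail branch
    printf("%g %g\n", b[0], b[22]); // prints: 0 22
}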
inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
{
    // 16-8-4-2-1 pattern
    ck::index_t i_n = n;
    float* p_dst    = reinterpret_cast<float*>(dst);
    __m256 ymm      = _mm256_set1_ps(*reinterpret_cast<const float*>(&value));
    __m128 xmm      = _mm_set1_ps(*reinterpret_cast<const float*>(&value));
    while(i_n >= 16)
    {
        _mm256_storeu_ps(p_dst + 0, ymm);
        _mm256_storeu_ps(p_dst + 8, ymm);
        p_dst += 16;
        i_n -= 16;
    }
    if(i_n & 8)
    {
        _mm256_storeu_ps(p_dst, ymm);
        p_dst += 8;
    }
    if(i_n & 4)
    {
        _mm_storeu_ps(p_dst, xmm);
        p_dst += 4;
    }
    if(i_n & 2)
    {
        // cast needed: _mm_storeu_si64 takes an __m128i, xmm is an __m128
        _mm_storeu_si64(p_dst, _mm_castps_si128(xmm));
        p_dst += 2;
    }
    if(i_n & 1)
    {
        *p_dst = *reinterpret_cast<const float*>(&value);
    }
}
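Note that the value parameter is a raw 32-bit pattern: it is bit-cast to float and splatted, not numerically converted. A scalar model of that splat (hypothetical, standalone):

// demo_memset32.cpp, scalar model of the bit-pattern splat
#include <cstdint>
#include <cstring>
#include <cstdio>

int main()
{
    float one = 1.0f;
    int32_t bits;
    std::memcpy(&bits, &one, sizeof(bits)); // bits == 0x3f800000
    float buf[5];
    for(float& f : buf)
        std::memcpy(&f, &bits, sizeof(f)); // splat the pattern, no conversion
    printf("%.1f %.1f\n", buf[0], buf[4]); // prints: 1.0 1.0
}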
inline void
transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
{
    // TODO: use vinsertf128 for better port usage. vpermf128 is slow
    __m256 r0, r1, r2, r3, r4, r5, r6, r7;
    __m256 t0, t1, t2, t3, t4, t5, t6, t7;

    float* p_dst       = reinterpret_cast<float*>(dst);
    const float* p_src = reinterpret_cast<const float*>(src);

    r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
    r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
    r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
    r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
    r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
    r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
    r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
    r7 = _mm256_loadu_ps(p_src + 7 * stride_src);

    t0 = _mm256_unpacklo_ps(r0, r1);
    t1 = _mm256_unpackhi_ps(r0, r1);
    t2 = _mm256_unpacklo_ps(r2, r3);
    t3 = _mm256_unpackhi_ps(r2, r3);
    t4 = _mm256_unpacklo_ps(r4, r5);
    t5 = _mm256_unpackhi_ps(r4, r5);
    t6 = _mm256_unpacklo_ps(r6, r7);
    t7 = _mm256_unpackhi_ps(r6, r7);

    r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
    r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
    r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
    r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
    r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
    r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
    r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
    r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));

    t0 = _mm256_permute2f128_ps(r0, r4, 0x20);
    t1 = _mm256_permute2f128_ps(r1, r5, 0x20);
    t2 = _mm256_permute2f128_ps(r2, r6, 0x20);
    t3 = _mm256_permute2f128_ps(r3, r7, 0x20);
    t4 = _mm256_permute2f128_ps(r0, r4, 0x31);
    t5 = _mm256_permute2f128_ps(r1, r5, 0x31);
    t6 = _mm256_permute2f128_ps(r2, r6, 0x31);
    t7 = _mm256_permute2f128_ps(r3, r7, 0x31);

    _mm256_storeu_ps(p_dst + 0 * stride_dst, t0);
    _mm256_storeu_ps(p_dst + 1 * stride_dst, t1);
    _mm256_storeu_ps(p_dst + 2 * stride_dst, t2);
    _mm256_storeu_ps(p_dst + 3 * stride_dst, t3);
    _mm256_storeu_ps(p_dst + 4 * stride_dst, t4);
    _mm256_storeu_ps(p_dst + 5 * stride_dst, t5);
    _mm256_storeu_ps(p_dst + 6 * stride_dst, t6);
    _mm256_storeu_ps(p_dst + 7 * stride_dst, t7);
}
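The unpack/shuffle/permute2f128 sequence realizes a full 8x8 float transpose; both strides are in floats. A minimal correctness check (hypothetical driver; assumes this header and its ck prerequisites are on the include path and the file is built with -mavx2):

// demo_transpose8x8.cpp
#include <cstdio>

int main()
{
    float src[64], dst[64];
    for(int r = 0; r < 8; ++r)
        for(int c = 0; c < 8; ++c)
            src[r * 8 + c] = float(r * 8 + c);
    ck::cpu::avx2_util::transpose8x8_avx2(dst, 8, src, 8);
    // expect dst[c * 8 + r] == src[r * 8 + c] for every r, c
    printf("%g %g\n", dst[1 * 8 + 0], dst[0 * 8 + 1]); // prints: 1 8
}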
} // namespace avx2_util

using ConvolutionForwardSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t;
using ConvolutionForwardGemmKSpecialization_t =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t;

// assume input -> a matrix
// assume input -> MC * KC
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;
    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc&,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            N  = 1;
            Hi = 1;
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}]; // gemm_m
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}]; // gemm_k
            Ho = 1;
            Wo = Wi;
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = 1;
            Dx = 1;
            Sx = 1;
            Py = 0;
            Px = 0;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<2>{}].GetUpperLengths()[Number<0>{}];
            Wo = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<0>{}];
            Fy = 1;
            Fx = 1;
            Dy = 1;
            Sy = src_desc.GetTransforms()[Number<2>{}].coefficients_[Number<0>{}];
            Dx = 1;
            Sx = src_desc.GetTransforms()[Number<3>{}].coefficients_[Number<0>{}];
            Py = 0;
            Px = 0;
        }
        else
        {
            N  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            Hi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
            Wi = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            C  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<3>{}];
            Ho = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<1>{}];
            Wo = src_desc.GetTransforms()[Number<9>{}].low_lengths_[Number<2>{}];
            Fy = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<0>{}];
            Fx = src_desc.GetTransforms()[Number<10>{}].low_lengths_[Number<1>{}];
            Dy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<0>{}];
            Sy = src_desc.GetTransforms()[Number<6>{}].coefficients_[Number<1>{}];
            Dx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<0>{}];
            Sx = src_desc.GetTransforms()[Number<7>{}].coefficients_[Number<1>{}];
            Py = src_desc.GetTransforms()[Number<2>{}].left_pad_length_;
            Px = src_desc.GetTransforms()[Number<3>{}].left_pad_length_;
        }

        // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
        // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
        input_offset_acc_wi        = Sx * C;
        input_offset_ovf_wi_acc_hi = Sy * Wi * C - Wo * Sx * C;
        input_offset_ovf_hi_acc_n  = Hi * Wi * C - Ho * Sy * Wi * C;
        // input_offset_acc_c = 1;
        input_offset_ovf_c_acc_x = Dx * C - C;
        input_offset_ovf_x_acc_y = Dy * Wi * C - Fx * Dx * C;

        src_offset = -Py * Wi * C - Px * C;

        i_n      = 0;
        i_c      = 0;
        i_hi     = -Py;
        i_wi     = -Px;
        i_ho     = 0;
        i_wo     = 0;
        i_y      = 0;
        i_x      = 0;
        i_gemm_k = 0;

#if 0
        printf("N:%d, Hi:%d, Wi:%d, C:%d, Ho:%d, Wo:%d, Fy:%d, Fx:%d, Dy:%d, Sy:%d, Dx:%d, Sx:%d, "
               "Py:%d, Px:%d\n",
               N, Hi, Wi, C, Ho, Wo, Fy, Fx, Dy, Sy, Dx, Sx, Py, Px);
#endif
    }
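Each input_offset_* delta is the change in the flat NHWC offset when one loop counter advances; the "ovf" variants additionally unwind what the inner counter accumulated over its full sweep. A worked example with assumed sizes (illustration only):

// Assumed sizes: C = 4, Wi = 10, Sx = 2, Wo = 5, Sy = 2, Hi = 12, Ho = 6
//   input_offset_acc_wi        = Sx * C                     = 8
//   input_offset_ovf_wi_acc_hi = Sy*Wi*C - Wo*Sx*C          = 80 - 40 = 40
//   input_offset_ovf_hi_acc_n  = Hi*Wi*C - Ho*Sy*Wi*C       = 480 - 480 = 0
// After Wo = 5 steps of +8 (total +40) plus the +40 rollover, the offset has
// advanced exactly Sy*Wi*C = 80: one stride-h step down the input rows.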
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_m = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k = src_slice_origin_idx[Number<1>{}];

        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            i_wi = idx_m;
            i_c  = idx_k;

            src_offset = i_wi * C + i_c;
            // printf("src_offset:%d, i_wi:%d, i_c:%d\n", src_offset, i_wi, i_c);
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            i_c = idx_k;
            i_x = 0;
            i_y = 0;

            i_hi = i_ho * Sy;
            i_wi = i_wo * Sx;

            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
        }
        else
        {
            i_wo = idx_m % Wo;
            i_ho = (idx_m / Wo) % Ho;
            i_n  = (idx_m / Wo) / Ho;

            // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
            // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
            if(idx_k == 0)
            {
                i_c  = 0;
                i_x  = 0;
                i_y  = 0;
                i_hi = i_ho * Sy - Py;
                i_wi = i_wo * Sx - Px;
            }
            else
            {
                i_c  = idx_k % C;
                i_x  = (idx_k / C) % Fx;
                i_y  = (idx_k / C) / Fx;
                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;
            }
            src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            i_gemm_k   = idx_k;
            // printf("[%d] i_wo:%d, i_ho:%d, i_wi:%d, i_hi:%d, src_offset:%d\n",
            // __LINE__, i_wo, i_ho, i_wi, i_hi, src_offset);
        }
    }
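A sketch of the decomposition above with assumed sizes (illustration only): idx_m splits into (n, ho, wo) and idx_k into (y, x, c).

// Assumed sizes: Wo = 7, Ho = 5, C = 16, Fx = 3
//   idx_m = 38 -> i_wo = 38 % 7 = 3, i_ho = (38 / 7) % 5 = 0, i_n = (38 / 7) / 5 = 1
//   idx_k = 21 -> i_c = 21 % 16 = 5, i_x = (21 / 16) % 3 = 1, i_y = (21 / 16) / 3 = 0
// then i_hi = i_ho*Sy + i_y*Dy - Py and i_wi = i_wo*Sx + i_x*Dx - Px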
    void SetDstSliceOrigin(const DstDesc&, const Index&) {}
    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc,
             const SrcBuffer& src_buf,
             const DstDesc& dst_desc,
             DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            float* p_src    = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            dst_buf.p_data_ = p_src;
        }
        else
        {
            const ck::index_t m_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // printf("src offset:%d, k_per_block:%d, m_per_block:%d\n", src_offset, k_per_block,
            // m_per_block);

            if constexpr(ConvForwardSpecialization ==
                         ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
            {
                ck::index_t i_m_itr = m_per_block;

                // standard 8-4-2-1 pattern
                while(i_m_itr >= 8)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);

                    i_m_itr -= 8;
                    p_dst += 8 * k_per_block;
                    p_src += 8 * C;
                }
                if(i_m_itr & 4)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
                    p_dst += 4 * k_per_block;
                    p_src += 4 * C;
                }
                if(i_m_itr & 2)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
                    p_dst += 2 * k_per_block;
                    p_src += 2 * C;
                }
                if(i_m_itr & 1)
                {
                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
                }
            }
            else if constexpr(ConvForwardSpecialization ==
                              ConvolutionForwardSpecialization_t::Filter1x1Pad0)
            {
                ck::index_t i_m_itr  = m_per_block;
                ck::index_t i_wo_itr = i_wo;
                ck::index_t i_ho_itr = i_ho;

                while(i_m_itr > 0)
                {
                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                    p_dst += k_per_block;

                    i_wo_itr++;
                    p_src += input_offset_acc_wi;
                    if(i_wo_itr >= Wo)
                    {
                        i_wo_itr = 0;
                        i_ho_itr++;
                        p_src += input_offset_ovf_wi_acc_hi;
                    }
                    if(i_ho_itr >= Ho)
                    {
                        i_ho_itr = 0;
                        // i_n++;
                        p_src += input_offset_ovf_hi_acc_n;
                    }
                    i_m_itr--;
                }
            }
            else
            {
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                if constexpr(GemmKSpecialization ==
                             ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
                {
                    // c % k_per_block == 0, so every time k_per_block here is the same
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d, i_hi_itr:%d,
                    // src_offset:%d, input_offset_acc_wi:%d,
                    // input_offset_ovf_wi_acc_hi:%d,input_offset_ovf_hi_acc_n:%d, %p(%p)\n",
                    // __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr,
                    // src_offset, input_offset_acc_wi, input_offset_ovf_wi_acc_hi,
                    // input_offset_ovf_hi_acc_n, src_buf.p_data_, p_src);

                    // printf("%p %p %p, %d, %x, %p\n",src_buf.p_data_, reinterpret_cast<const
                    // float*>(src_buf.p_data_) + 1, reinterpret_cast<const float*>(src_buf.p_data_)
                    // + ck::index_t(-1),
                    // sizeof(src_offset), *reinterpret_cast<uint32_t*>(&src_offset),
                    // reinterpret_cast<const float*>(src_buf.p_data_) + (-1088));

                    while(i_m_itr > 0)
                    {
                        // printf("[%d] i_m_itr:%d, i_wo_itr:%d, i_ho_itr:%d, i_wi_itr:%d,
                        // i_hi_itr:%d, src_offset:%d -> %p\n",
                        // __LINE__, i_m_itr, i_wo_itr, i_ho_itr, i_wi_itr, i_hi_itr, src_offset,
                        // p_src);
                        if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                           (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
                            avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
                        else
                            avx2_util::memset32_avx2(p_dst, 0, k_per_block);
                        p_dst += k_per_block;

                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                    // printf("[%d] \n", __LINE__);
                }
                else
                {
                    ck::index_t i_m_itr  = m_per_block;
                    ck::index_t i_wo_itr = i_wo;
                    ck::index_t i_ho_itr = i_ho;
                    ck::index_t i_wi_itr = i_wi;
                    ck::index_t i_hi_itr = i_hi;

                    // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                    // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                    while(i_m_itr > 0)
                    {
                        /*** go along Gemm K ***/
                        const float* p_src_k = p_src;
                        float* p_dst_k       = p_dst;

                        ck::index_t i_wi_itr_k = i_wi_itr;
                        ck::index_t i_hi_itr_k = i_hi_itr;
                        ck::index_t i_c_itr_k  = i_c;
                        ck::index_t i_y_itr_k  = i_y;
                        ck::index_t i_x_itr_k  = i_x;

                        ck::index_t i_k_itr = k_per_block;

                        while(i_k_itr > 0)
                        {
                            ck::index_t current_k_block = ck::math::min(C - i_c_itr_k, k_per_block);
                            if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                               (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
                                avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
                            else
                                avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);

                            p_dst_k += current_k_block;
                            p_src_k += current_k_block;
                            i_c_itr_k += current_k_block;
                            if(i_c_itr_k >= C)
                            {
                                i_c_itr_k = 0;
                                i_x_itr_k++;
                                i_wi_itr_k += Dx;
                                p_src_k += input_offset_ovf_c_acc_x;
                            }
                            if(i_x_itr_k >= Fx)
                            {
                                i_x_itr_k = 0;
                                i_y_itr_k++;
                                i_hi_itr_k += Dy;
                                i_wi_itr_k -= Dx * Fx;
                                p_src_k += input_offset_ovf_x_acc_y;
                            }
                            i_k_itr -= current_k_block;
                        }
                        /*** go along Gemm K ***/

                        p_dst += k_per_block;

                        i_wo_itr++;
                        i_wi_itr += Sx;
                        p_src += input_offset_acc_wi;
                        if(i_wo_itr >= Wo)
                        {
                            i_wo_itr = 0;
                            i_wi_itr -= Wo * Sx;
                            i_ho_itr++;
                            i_hi_itr += Sy;
                            p_src += input_offset_ovf_wi_acc_hi;
                        }
                        if(i_ho_itr >= Ho)
                        {
                            i_ho_itr = 0;
                            i_hi_itr -= Ho * Sy;
                            // i_n++;
                            p_src += input_offset_ovf_hi_acc_n;
                        }
                        i_m_itr--;
                    }
                }
            }
        }
    }
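The substantive change in this hunk is the added i_wi_itr_k -= Dx * Fx; in the inner GemmK loop: when the filter-x counter wraps, the tracked input-w index must be rewound before the padding test (uint32_t)i_wi_itr_k < Wi runs again, otherwise it drifts right by Dx*Fx per filter row. A scalar model of the fixed walk (illustration only; assumed Fx = 3, Dx = 2):

// Scalar model of the fixed x/y iterator walk
int i_x = 0, i_wi = 4; // i_wi tracks iwo*Sx + ix*Dx - Px
for(int step = 0; step < 3; ++step)
{
    i_x++;
    i_wi += 2;         // += Dx, as on c-overflow above
    if(i_x >= 3)       // Fx
    {
        i_x = 0;
        i_wi -= 2 * 3; // the rewind this commit adds (Dx * Fx)
    }
}
// i_wi is back at 4 after one full x sweep; dropping the rewind leaves it at 10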
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k = src_slice_origin_step_idx[Number<1>{}];
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
        {
            // printf(" => move_k:%d, src offset:%d\n", move_k, src_offset);
            i_c += move_k;
            src_offset += move_k;
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization_t::Filter1x1Pad0)
        {
            i_c += move_k;
            src_offset += move_k;
        }
        else
        {
            if constexpr(GemmKSpecialization ==
                         ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC)
            {
                // c % k_per_block == 0, so every time k_per_block here is the same
                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w
                // printf("222222 C:%d, src_offset:%d, i_c:%d, i_x:%d\n", C, src_offset, i_c, i_x);
                // fflush(stdout);

                // TODO: branch seems weird
                i_c += move_k;
                src_offset += move_k;
                // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                if(i_c >= C)
                {
                    i_c = 0;
                    i_x++;
                    i_wi += Dx;
                    src_offset += Dx * C - C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                if(i_x >= Fx)
                {
                    i_x = 0;
                    i_y++;
                    i_wi = i_wi - Fx * Dx;
                    i_hi += Dy;
                    src_offset += Dy * Wi * C - Fx * Dx * C;
                    // printf("3333[%d] src_offset:%d\n", __LINE__, src_offset);
                }
                // printf("inp move:%d, i_c:%d, i_hi:%d, i_wi:%d src_offset:%d\n", move_k, i_c,
                // i_hi, i_wi, src_offset); fflush(stdout);
            }
            else
            {
                i_gemm_k += move_k;

                i_c = i_gemm_k % C;
                i_x = (i_gemm_k / C) % Fx;
                i_y = (i_gemm_k / C) / Fx;

                i_hi = i_ho * Sy + i_y * Dy - Py;
                i_wi = i_wo * Sx + i_x * Dx - Px;

                src_offset = i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C + i_c;
            }
        }
    }
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_n;
    ck::index_t i_c;
    ck::index_t i_hi;
    ck::index_t i_wi;
    ck::index_t i_ho;
    ck::index_t i_wo;
    ck::index_t i_y;
    ck::index_t i_x;
    ck::index_t i_gemm_k;

    ck::index_t N;
    // ck::index_t K;
    ck::index_t C;
    ck::index_t Hi;
    ck::index_t Wi;
    ck::index_t Ho;
    ck::index_t Wo;
    ck::index_t Sy;
    ck::index_t Sx;
    ck::index_t Dy;
    ck::index_t Dx;
    ck::index_t Py;
    ck::index_t Px;
    ck::index_t Fy;
    ck::index_t Fx;

    intptr_t input_offset_acc_wi;
    intptr_t input_offset_ovf_wi_acc_hi;
    intptr_t input_offset_ovf_hi_acc_n;
    // intptr_t input_offset_acc_c;
    intptr_t input_offset_ovf_c_acc_x;
    intptr_t input_offset_ovf_x_acc_y;

    intptr_t src_offset; // keep this as pointer type in case we have negative offset
};
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;

    // using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    // using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        GemmN1 = src_desc.GetTransforms()[Number<3>{}].GetUpperLengths()[Number<1>{}];
        GemmN  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        GemmK  = src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];
    }
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        ck::index_t idx_n0 = src_slice_origin_idx[Number<0>{}];
        ck::index_t idx_k  = src_slice_origin_idx[Number<1>{}];
        ck::index_t idx_n1 = src_slice_origin_idx[Number<2>{}];

        i_gemm_n = idx_n0 * GemmN1 + idx_n1;
        // i_gemm_k = idx_k;

        // Note we transpose here
        src_offset = idx_n0 * GemmK * GemmN1 + idx_k + idx_n1 * GemmN1;
        // printf("xxxx i_gemm_n:%d, i_gemm_k:%d, src_offset:%d\n", i_gemm_n, i_gemm_k,
        // src_offset);
    }

    void SetDstSliceOrigin(const DstDesc&, const Index&) {}
    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            // TODO: weight NHWC not support this
        }
        else
        {
            const ck::index_t n_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}] *
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<2>{}];
            const ck::index_t k_per_block =
                dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            // printf(" >>>> %d, %d, %d -> %d(%dx%d), %d\n", GemmN, GemmK, GemmN1, n_per_block,
            // dst_desc.GetTransforms()[Number<0>{}]
            // .GetUpperLengths()[Number<0>{}],
            // dst_desc.GetTransforms()[Number<0>{}]
            // .GetUpperLengths()[Number<2>{}],
            // k_per_block);

            const float* p_src = reinterpret_cast<const float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_);

            // n * k -> n0 * k * n1, n1 = 8, n0 = n/8
            for(index_t i_n_itr = 0; i_n_itr < n_per_block; i_n_itr += 8)
            {
                ck::index_t current_n_8 = ck::math::min(GemmN - (i_n_itr + i_gemm_n), 8);
                ck::index_t i_k_itr     = k_per_block;
                if(current_n_8 == 8)
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    while(i_k_itr >= 8)
                    {
                        avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
                        p_dst_k += 8 * 8;
                        p_src_k += 8;
                        i_k_itr -= 8;
                    }
                    if(i_k_itr & 4)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
                        p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
                        p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
                        p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
                        p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
                        p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
                        p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
                        p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];

                        p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
                        p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
                        p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
                        p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
                        p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
                        p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
                        p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
                        p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];

                        p_dst_k += 4 * 8;
                        p_src_k += 4;
                    }
                    if(i_k_itr & 2)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];

                        p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
                        p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
                        p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
                        p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
                        p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
                        p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
                        p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
                        p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];

                        p_dst_k += 2 * 8;
                        p_src_k += 2;
                    }
                    if(i_k_itr & 1)
                    {
                        p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
                        p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
                        p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
                        p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
                        p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
                        p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
                        p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
                        p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
                    }
                }
                else
                {
                    const float* p_src_k = p_src;
                    float* p_dst_k       = p_dst;
                    for(index_t i_sub_n = 0; i_sub_n < 8; i_sub_n++)
                    {
                        for(index_t i_sub_k = 0; i_sub_k < k_per_block; i_sub_k++)
                        {
                            ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
                            float v                     = i_current_n_itr < GemmN
                                                              ? p_src_k[i_sub_n * GemmK + i_sub_k]
                                                              : .0f;
                            p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                        }
                    }
                }
                p_dst += 8 * k_per_block;
                p_src += 8 * GemmK;
            }
        }
    }
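A layout sketch for the n*k -> n0*k*n1 packing above, with n1 = 8 (illustration only, assumed sizes):

//   source, row-major n x k : src[n * GemmK + k]
//   packed                  : dst[(n / 8) * (k_per_block * 8) + k * 8 + (n % 8)]
// e.g. GemmK = k_per_block = 16: element (n = 10, k = 3) lands at
//   (10 / 8) * 128 + 3 * 8 + 2 = 154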
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        ck::index_t move_k  = src_slice_origin_step_idx[Number<1>{}];
        ck::index_t move_n0 = src_slice_origin_step_idx[Number<0>{}];
        // i_gemm_k += move_k;
        // printf("wei move:%d\n", move_k); fflush(stdout);
        src_offset += move_k + move_n0 * GemmK * GemmN1;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}

    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_gemm_n;
    // ck::index_t i_gemm_k;

    // ck::index_t GemmN0;
    ck::index_t GemmN1;
    ck::index_t GemmN;
    ck::index_t GemmK;

    intptr_t src_offset;
};
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          bool BypassTransfer,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ConvolutionForwardGemmKSpecialization_t GemmKSpecialization>
struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
{
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN(
        const SrcDesc& src_desc,
        const Index&,
        const DstDesc& dst_desc,
        const Index&,
        const ElementwiseOperation& element_op)
        : element_op_(element_op)
    {
        DstGemmM = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
        DstGemmN = dst_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

        src_offset = 0;
        dst_offset = 0;
    }
    void SetSrcSliceOrigin(const SrcDesc&, const Index& src_slice_origin_idx)
    {
        if constexpr(BypassTransfer)
        {
            auto i_src_gemm_m = src_slice_origin_idx[Number<0>{}];
            auto i_src_gemm_n = src_slice_origin_idx[Number<1>{}];

            src_offset = i_src_gemm_m * DstGemmN + i_src_gemm_n;
        }
    }
    void SetDstSliceOrigin(const DstDesc&, const Index& dst_slice_origin_idx)
    {
        i_dst_gemm_m = dst_slice_origin_idx[Number<0>{}];
        i_dst_gemm_n = dst_slice_origin_idx[Number<1>{}];

        dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
    }
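    // (Editorial note) The destination is addressed as a row-major DstGemmM x DstGemmN
    // matrix, so origin {i_dst_gemm_m, i_dst_gemm_n} maps to the flat offset
    // i_dst_gemm_m * DstGemmN + i_dst_gemm_n; e.g. origin {2, 8} with a hypothetical
    // DstGemmN = 256 lands at 2 * 256 + 8 = 520.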
    template <typename SrcBuffer, typename DstBuffer>
    void Run(const SrcDesc& src_desc, SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if constexpr(BypassTransfer)
        {
            src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
        }
        else
        {
            const ck::index_t m_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<0>{}];
            // must be multiple of 8
            const ck::index_t n_per_block =
                src_desc.GetTransforms()[Number<0>{}].GetUpperLengths()[Number<1>{}];

            const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);

            const float* p_src = reinterpret_cast<float*>(src_buf.p_data_) + src_offset;
            float* p_dst       = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;

            ck::index_t i_m_itr = m_per_block;

            // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d\n",__LINE__, current_n,
            // DstGemmN, n_per_block);fflush(stdout);

            // standard 8-4-2-1 pattern
            while(i_m_itr >= 8)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);

                i_m_itr -= 8;
                p_dst += 8 * DstGemmN;
                p_src += 8 * n_per_block;
            }

            if(i_m_itr & 4)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);

                p_dst += 4 * DstGemmN;
                p_src += 4 * n_per_block;
            }

            if(i_m_itr & 2)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);

                p_dst += 2 * DstGemmN;
                p_src += 2 * n_per_block;
            }

            if(i_m_itr & 1)
            {
                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
            }
            // printf("xxxx %d\n",__LINE__);fflush(stdout);
        }
    }
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveSrcSliceWindow(const SrcDesc&, const Index&) {}

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    void MoveDstSliceWindow(const DstDesc&, const Index&) {}
    private:
    const ElementwiseOperation element_op_;

    ck::index_t i_dst_gemm_m;
    ck::index_t i_dst_gemm_n;

    ck::index_t DstGemmM;
    ck::index_t DstGemmN;

    intptr_t src_offset;
    intptr_t dst_offset;
};
} // namespace cpu
} // namespace ck

#endif
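Editorial note: the Run path above drains the row count with an unrolled 8-4-2-1 remainder pattern instead of a scalar tail loop. A minimal standalone sketch of the same idea follows; row_copy and copy_rows_8421 are hypothetical names for illustration only, not part of this header.

#include <cstddef>

// Plain per-row copy standing in for avx2_util::memcpy32_avx2.
static void row_copy(float* dst, const float* src, std::size_t n)
{
    for(std::size_t j = 0; j < n; ++j)
        dst[j] = src[j];
}

// Copy 'rows' rows of 'n' floats each; the main loop handles 8 rows per
// iteration, and the bit tests 4/2/1 drain any remainder 0..7 with at most
// three branches, mirroring the structure of Run above.
static void copy_rows_8421(float* dst, std::size_t dst_stride,
                           const float* src, std::size_t src_stride,
                           std::size_t rows, std::size_t n)
{
    std::size_t i = rows;
    while(i >= 8)
    {
        for(int r = 0; r < 8; ++r)
            row_copy(dst + r * dst_stride, src + r * src_stride, n);
        dst += 8 * dst_stride;
        src += 8 * src_stride;
        i -= 8;
    }
    if(i & 4)
    {
        for(int r = 0; r < 4; ++r)
            row_copy(dst + r * dst_stride, src + r * src_stride, n);
        dst += 4 * dst_stride;
        src += 4 * src_stride;
    }
    if(i & 2)
    {
        for(int r = 0; r < 2; ++r)
            row_copy(dst + r * dst_stride, src + r * src_stride, n);
        dst += 2 * dst_stride;
        src += 2 * src_stride;
    }
    if(i & 1)
        row_copy(dst, src, n);
}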
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
#include <stdlib.h>
#include "convolution_forward_specialization_cpu.hpp"
#include "config.hpp"
#include "device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation_cpu.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

using InLayout  = ck::tensor_layout::gemm::RowMajor;    // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PT = ck::tensor_operation::cpu::element_wise::PassThrough;

using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
    ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                  WeiType,
                                                  OutType,
                                                  InLayout,
                                                  WeiLayout,
                                                  NonTemporalStore>;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;

static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;
// clang-format off
#define DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(a_elem_op, b_elem_op, c_elem_op, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf) \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC  , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0  , GemmKLoopOverC  , LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0  , DefaultGemmKLoop, LoopOver_MNK, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, GemmKLoopOverC  , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0  , GemmKLoopOverC  , LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwdDefault, DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>, \
    DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<float, float, float, a_elem_op, b_elem_op, c_elem_op, ConvFwd1x1P0  , DefaultGemmKLoop, LoopOver_MKN, 2, m_per_block, n_per_block, k_per_block, m_per_thread, n_per_thread, a_local_buf, b_local_buf, c_local_buf>
// clang-format on
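// (Editorial note) Each use of the macro above therefore contributes eight instance
// types: {ConvFwdDefault, ConvFwd1x1P0} x {GemmKLoopOverC, DefaultGemmKLoop} x
// {LoopOver_MNK, LoopOver_MKN}, all sharing one blocking/threading configuration.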
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 120,  64, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 144, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 512, 240, 128, 4, 24, true, true, false),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 192, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 768, 288, 128, 4, 24, true, true, false)
    // clang-format on
    >;
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_instances{});
}
} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
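Editorial note: a caller obtains these instances through the registration function above and can probe each one by name. A minimal fragment under the same namespaces (not a complete program; DeviceConvFwdPtr is assumed here to expose only the GetTypeString() seen elsewhere in this commit):

std::vector<DeviceConvFwdPtr<PT, PT, PT>> instances;
ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(instances);
for(const auto& conv_ptr : instances)
    std::cout << conv_ptr->GetTypeString() << std::endl;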
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
...
@@ -37,26 +37,53 @@ using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result, double nrms, int per_pixel_check = 0)
{
    int error_count = 0;
    float max_diff  = 1e-5;

    double square_difference = .0;
    double mag1              = .0;
    double mag2              = .0;

    for(int i = 0; i < ref.mData.size(); ++i)
    {
        double ri = (double)ref.mData[i];
        double pi = (double)result.mData[i];
        double d  = ri - pi;

        if(per_pixel_check)
        {
            if(max_diff < std::abs(d))
            {
                error_count++;
                printf("idx:%3d, ref:%f, res:%f (diff:%f)\n",
                       i,
                       double(ref.mData[i]),
                       double(result.mData[i]),
                       d);
            }
        }

        square_difference += d * d;
        if(std::abs(mag1) < std::abs(ri))
            mag1 = ri;
        if(std::abs(mag2) < std::abs(pi))
            mag2 = pi;
    }

    double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
    double computed_nrms = std::sqrt(square_difference) / (std::sqrt(ref.mData.size()) * mag);

    if(computed_nrms >= nrms)
        printf("nrms:%lf, mag1:%lf, mag2:%lf, expected_nrms is %lf\n",
               computed_nrms, mag1, mag2, nrms);

    return computed_nrms < nrms && error_count == 0;
}

float calculate_gflops() {}
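Editorial note: check_out above gates on a normalized RMS error, sqrt(sum((ref - res)^2)) / (sqrt(count) * max_magnitude). A self-contained illustration of the same metric (hypothetical nrms_of helper; values chosen only for the example):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Compute the normalized RMS error used by check_out above.
static double nrms_of(const std::vector<double>& ref, const std::vector<double>& res)
{
    double sq  = 0.0;
    double mag = std::numeric_limits<double>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        double d = ref[i] - res[i];
        sq += d * d;
        mag = std::max({mag, std::fabs(ref[i]), std::fabs(res[i])});
    }
    return std::sqrt(sq) / (std::sqrt(static_cast<double>(ref.size())) * mag);
}

int main()
{
    std::vector<double> ref{1.0, 2.0, 4.0};
    std::vector<double> res{1.0, 2.0, 4.001};
    std::printf("nrms: %g\n", nrms_of(ref, res)); // ~1.4e-4, passes a 1e-3 bound
    return 0;
}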
...
@@ -171,20 +198,28 @@ int main(int argc, char* argv[])
              << ", Dilation(H, W):" << conv_dilation_h << ", " << conv_dilation_w
              << ", Threads:" << omp_get_max_threads() << std::endl;

    int per_pixel_check = 0;

    switch(init_method)
    {
    case 0:
        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
        per_pixel_check = 1;
        break;
    case 1:
        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
        // in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
        // wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
        per_pixel_check = 1;
        break;
    case 2:
        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
        break;
    case 3:
#define PACK_32(v24, v16, v8, v0) \
...
@@ -310,7 +345,10 @@ int main(int argc, char* argv[])
    out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
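    // (Editorial note) per_pixel_check is only set for the integer-valued
    // initializations (cases 0 and 1 above), where an exact elementwise match is
    // reasonable to demand; the random-float initialization (case 2) is judged by
    // the NRMS bound alone.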
    if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result, 1e-6, per_pixel_check))
    {
        std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
        success = false;
...