gaoqiong / composable_kernel · Commits

Commit 090ba885 · authored May 15, 2022 by carlushuang · parent 8ce9fe57

    add elementwise fusion support

Showing 7 changed files with 842 additions and 510 deletions (+842 -510)
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp                  +5    -0
include/ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp                            +342  -370
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp                                       +109  -14
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp   +253  -124
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  +76  -1
profiler/include/profile_conv_fwd_cpu_impl.hpp                                                    +9    -0
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp                                                            +48   -1
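The commit can be summarized independently of the CK plumbing: every fused output operator exposes scalar, __m128 and __m256 `Apply` overloads plus a `Name()` used for instance-name reporting, and the AVX2 copy helpers apply the operator while the GEMM output is being moved, so the fusion costs no extra pass over memory. A minimal self-contained sketch of that pattern (the `float4_t`/`float8_t` aliases here are stand-ins for the CK typedefs, not the library's headers):

#include <immintrin.h>
#include <cstdio>

// Stand-ins for ck::cpu::float4_t / float8_t.
using float4_t = __m128;
using float8_t = __m256;

// Same shape as the reworked element_wise::Relu: operator() forwards to Apply(),
// and Apply() is what the copy/transpose helpers call directly on each vector.
struct Relu
{
    void operator()(float& y, const float& x) const { y = Apply(x); }
    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return x > 0 ? x : 0; }
    float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
    float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }

    static constexpr const char* Name() { return "Relu"; }
};

int main()
{
    float in[8] = {-2, -1, 0, 1, 2, 3, -4, 5};
    float out[8];
    Relu op;
    // A fused store: the elementwise op is applied on the way from the GEMM
    // output tile to the destination buffer, in a single pass.
    _mm256_storeu_ps(out, op.Apply(_mm256_loadu_ps(in)));
    for(int i = 0; i < 8; ++i)
        std::printf("%s(%g) = %g\n", Relu::Name(), in[i], out[i]);
    return 0;
}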
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp

@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             << "_B" << string_local_buffer(UseBLocalBuffer)
             << "_C" << string_local_buffer(UseCLocalBuffer);
+        if constexpr(!std::is_same<OutElementwiseOperation,
+                                   ck::tensor_operation::cpu::element_wise::PassThrough>::value)
+        {
+            str << "_" << OutElementwiseOperation::Name();
+        }
         // clang-format on
         return str.str();
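The added `if constexpr` only appends the fused operator's `Name()` to the instance's type string when the output op is not `PassThrough`. A small stand-alone illustration of the same mechanism (struct and base-name strings here are illustrative, not the CK types):

#include <sstream>
#include <string>
#include <type_traits>
#include <iostream>

struct PassThrough { static constexpr const char* Name() { return "PassThrough"; } };
struct Relu        { static constexpr const char* Name() { return "Relu"; } };

template <typename OutElementwiseOperation>
std::string type_string()
{
    std::ostringstream str;
    str << "DeviceConv2dFwdAvx2_NHWC_KYXC_NHWK"; // illustrative base name
    if constexpr(!std::is_same<OutElementwiseOperation, PassThrough>::value)
    {
        str << "_" << OutElementwiseOperation::Name();
    }
    return str.str();
}

int main()
{
    std::cout << type_string<PassThrough>() << "\n"; // no suffix
    std::cout << type_string<Relu>() << "\n";        // "..._Relu"
}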
include/ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp

@@ -11,90 +11,118 @@ using float4_t = ck::cpu::float4_t;

 struct PassThrough
 {
-    void operator()(float& y, const float& x) const { y = x; }
-    void operator()(float4_t& y, const float4_t& x) const { y = x; }
-    void operator()(float8_t& y, const float8_t& x) const { y = x; }
+    void operator()(float& y, const float& x) const { y = Apply(x); }
+    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
+    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
+
+    float Apply(const float& x) const { return x; }
+    float4_t Apply(const float4_t& x) const { return x; }
+    float8_t Apply(const float8_t& x) const { return x; }
+
+    static constexpr char* Name() { return "PassThrough"; }
 };

 struct Add
 {
-    void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }
+    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }
     void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
     {
-        y = _mm_add_ps(x0, x1);
+        y = Apply(x0, x1);
     }
     void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
     {
-        y = _mm256_add_ps(x0, x1);
+        y = Apply(x0, x1);
     }
+
+    float Apply(const float& x0, const float& x1) const { return x0 + x1; }
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }
+
+    static constexpr char* Name() { return "Add"; }
 };

+struct Relu
+{
+    void operator()(float& y, const float& x) const { y = Apply(x); }
+    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
+    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }
+
+    float Apply(const float& x) const { return x > 0 ? x : 0; }
+    float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
+    float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }
+
+    static constexpr char* Name() { return "Relu"; }
+};
+
 struct AlphaBetaAdd
 {
     AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

     void operator()(float& y, const float& x0, const float& x1) const
-    {
-        y = alpha_ * x0 + beta_ * x1;
-    }
+    {
+        y = Apply(x0, x1);
+    }
     void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
     {
-        y = _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
+        y = Apply(x0, x1);
     }
     void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
     {
-        y = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
-                          _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
+        y = Apply(x0, x1);
     }
+
+    float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const
+    {
+        return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)),
+                          _mm_mul_ps(x1, _mm_set1_ps(beta_)));
+    }
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const
+    {
+        return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
+                             _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
+    }
+
+    static constexpr char* Name() { return "AlphaBetaAdd"; }

     float alpha_;
     float beta_;
 };

 struct AddRelu
 {
     void operator()(float& y, const float& x0, const float& x1) const
     {
-        const float a = x0 + x1;
-        y             = a > 0 ? a : 0;
+        y = Apply(x0, x1);
     }
     void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
     {
-        y = _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
+        y = Apply(x0, x1);
     }
     void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
     {
-        y = _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
+        y = Apply(x0, x1);
     }
+
+    float Apply(const float& x0, const float& x1) const
+    {
+        const float a = x0 + x1;
+        return a > 0 ? a : 0;
+    }
+    float4_t Apply(const float4_t& x0, const float4_t& x1) const
+    {
+        return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
+    }
+    float8_t Apply(const float8_t& x0, const float8_t& x1) const
+    {
+        return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
+    }
+
+    static constexpr char* Name() { return "AddRelu"; }
 };

-#if 0
-struct AddHardswish
-{
-    void operator()(float& y, const float& x0, const float& x1) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        y = c;
-    }
-    void operator()(half_t& y, const half_t& x0, const half_t& x1) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        y = c;
-    }
-};
-#endif
-
 struct AddReluAdd
 {

@@ -119,65 +147,9 @@ struct AddReluAdd
         float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
         y          = _mm256_add_ps(b, x2);
     }
+
+    static constexpr char* Name() { return "AddReluAdd"; }
 };
-#if 0
-struct AddHardswishAdd
-{
-    void
-    operator()(float& y, const float& x0, const float& x1, const float& x2) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        float d = c + x2;
-        y = d;
-    }
-    void
-    operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
-    {
-        float a = x0 + x1;
-        float b = a + float{3};
-        float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
-        float d = c + x2;
-        y = d;
-    }
-};
-#endif
-#if 0
-struct RequantReluRequant
-{
-    // FIXME: We just need one scale for Relu / Leaky Relu / PRelu
-    RequantReluRequant(float scaleGemm, float scaleRelu)
-        : scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
-    {
-    }
-    void operator()(int8_t& y, const int& x) const
-    {
-        float gemm_requant = scaleGemm_ * static_cast<float>(x);
-        float relu         = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<int8_t>(relu_requant > 127 ? 127
-                                : relu_requant < -128 ? -128 : relu_requant);
-    }
-    // for reference_gemm
-    void operator()(float& y, const float& x) const
-    {
-        float gemm_requant = scaleGemm_ * x;
-        float relu         = gemm_requant > 0 ? gemm_requant : 0;
-        float relu_requant = scaleRelu_ * relu;
-        y = static_cast<float>(relu_requant > 127 ? 127
-                               : relu_requant < -128 ? -128 : relu_requant);
-    }
-    float scaleGemm_;
-    float scaleRelu_;
-};
-#endif

 // Unary operators are usually called element-wisely before/after the reduction is executed on the
 // elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
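Any additional fused output operator only has to follow the same shape: the three `Apply` overloads do the math, `operator()` forwards to them, and `Name()` feeds the instance string. A hypothetical `Scale` operator as an illustration (this struct is not part of the commit; the vector-type aliases are stand-ins for the CK typedefs):

#include <immintrin.h>

using float4_t = __m128;  // stand-in for ck::cpu::float4_t
using float8_t = __m256;  // stand-in for ck::cpu::float8_t

struct Scale
{
    explicit Scale(float s) : scale_(s) {}

    void operator()(float& y, const float& x) const { y = Apply(x); }
    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }
    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return scale_ * x; }
    float4_t Apply(const float4_t& x) const { return _mm_mul_ps(x, _mm_set1_ps(scale_)); }
    float8_t Apply(const float8_t& x) const { return _mm256_mul_ps(x, _mm256_set1_ps(scale_)); }

    static constexpr const char* Name() { return "Scale"; }

    float scale_;
};

int main()
{
    float out[8];
    _mm256_storeu_ps(out, Scale{1.5f}.Apply(_mm256_set1_ps(2.0f)));
    return out[0] == 3.0f ? 0 : 1; // 1.5 * 2.0
}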
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp

@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
         return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
     }

+    static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
+    {
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // A : M, K
+            return ck::make_multi_index(m_per_blk, k_per_blk);
+        }
+        else
+        {
+            // A : K, M
+            return ck::make_multi_index(
+                k_per_blk,
+                math::integer_least_multiple(m_per_blk,
+                                             ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
+        }
+    }
+
+    static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
+    {
+        // n_per_blk should be 8x
+        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
+                                  ck::tensor_layout::gemm::RowMajor>::value)
+        {
+            // B : K, N
+            return ck::make_multi_index(
+                k_per_blk,
+                math::integer_least_multiple(n_per_blk,
+                                             ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
+        }
+        else
+        {
+            // B : N/8, K, N8
+            return ck::make_multi_index(
+                math::integer_divide_ceil(n_per_blk,
+                                          ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
+                k_per_blk,
+                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
+        }
+    }
+
+    static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
+    {
+        return ck::make_multi_index(m_per_blk, n_per_blk);
+    }
+
     static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                         const BGridDesc& b_grid_desc,
                                         const CGridDesc& c_grid_desc)

@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
                 UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

             if constexpr(UseCLocalBuffer)
             {
-                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
-                                                    ck::make_multi_index(i_mc, i_nc));
+                // c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                //                                     ck::make_multi_index(i_mc, i_nc));
             }
             else
             {
                 c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                     ck::make_multi_index(i_mc, i_nc));
-                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                c_threadwise_copy.RunRead(c_block_desc,
+                                          c_block_buf,
+                                          c_grid_desc,
+                                          c_grid_buf,
+                                          GetCMultiIndex(mc_size, nc_size));
             }

             for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)

@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
                 auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                 auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

-                a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
-                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
+                a_threadwise_copy.RunRead(a_grid_desc,
+                                          a_grid_buf,
+                                          a_block_desc,
+                                          a_block_buf,
+                                          GetAMultiIndex(mc_size, kc_size));
+                b_threadwise_copy.RunRead(b_grid_desc,
+                                          b_grid_buf,
+                                          b_block_desc,
+                                          b_block_buf,
+                                          GetBMultiIndex(kc_size, nc_size));

                 blockwise_gemm.Run(a_block_desc,
                                    a_block_buf,

@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
                 }
             }

-            if constexpr(UseCLocalBuffer)
-                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+            // if constexpr(UseCLocalBuffer)
+            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
+            c_threadwise_copy.RunWrite(c_block_desc,
+                                       c_block_buf,
+                                       c_grid_desc,
+                                       c_grid_buf,
+                                       GetCMultiIndex(mc_size, nc_size));
         }
     }
 }

@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
             ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
             auto a_block_desc   = GetABlockDescriptor(mc_size, kc_size);

-            a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
+            a_threadwise_copy.RunRead(a_grid_desc,
+                                      a_grid_buf,
+                                      a_block_desc,
+                                      a_block_buf,
+                                      GetAMultiIndex(mc_size, kc_size));

             b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc,
                                                 ck::make_multi_index(0, i_kc, 0));

@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
                     nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
                 auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

-                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
+                b_threadwise_copy.RunRead(b_grid_desc,
+                                          b_grid_buf,
+                                          b_block_desc,
+                                          b_block_buf,
+                                          GetBMultiIndex(kc_size, nc_size));

                 auto c_block_desc = UseCLocalBuffer
                                         ? GetCBlockDescriptor(mc_size, nc_size)

@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
                 {
                     c_threadwise_copy.SetSrcSliceOrigin(
                         c_block_desc, ck::make_multi_index(i_mc, i_nc));
-                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                    c_threadwise_copy.RunRead(c_block_desc,
+                                              c_block_buf,
+                                              c_grid_desc,
+                                              c_grid_buf,
+                                              GetCMultiIndex(mc_size, nc_size));
                 }

                 blockwise_gemm.Run(a_block_desc,

@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
                     i_kc != 0);

                 if((i_nc + n_per_block) < GemmN)
+                {
                     b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
+                }

                 if constexpr(UseCLocalBuffer)
                 {
                     c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                         ck::make_multi_index(i_mc, i_nc));
-                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
+                    c_threadwise_copy.RunWrite(c_block_desc,
+                                               c_block_buf,
+                                               c_grid_desc,
+                                               c_grid_buf,
+                                               GetCMultiIndex(mc_size, nc_size));
                 }
+                else
+                {
+                    // only write for last K, since the RunWrite here is just doing
+                    // elementwise op from global to global
+                    if((i_kc + k_per_block) >= GemmK)
+                    {
+                        c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
+                                                            ck::make_multi_index(i_mc, i_nc));
+                        c_threadwise_copy.RunWrite(c_block_desc,
+                                                   c_block_buf,
+                                                   c_grid_desc,
+                                                   c_grid_buf,
+                                                   GetCMultiIndex(mc_size, nc_size));
+                    }
+                }
             }
         }
     }
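The point of the Run → RunRead/RunWrite split is when the output elementwise op gets applied. With a local C buffer, RunWrite copies block → global once per tile and applies the op during that copy; without a local C buffer the GEMM accumulates straight into global memory, so RunWrite (then just a global-to-global elementwise pass) is only invoked after the last K step, otherwise the op would hit partial sums. A minimal sketch of that decision, with plain ints standing in for the CK descriptors and a stub for the copy (illustrative, not the CK API):

#include <cstdio>

// Stand-in for c_threadwise_copy.RunWrite(...).
static void run_write(int i_kc) { std::printf("RunWrite after K step starting at %d\n", i_kc); }

int main()
{
    const int GemmK = 96, k_per_block = 32;
    const bool use_c_local_buffer = false; // GEMM accumulates directly into global C

    for(int i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
    {
        // ... RunRead of the A/B tiles and the blockwise GEMM go here ...
        if(use_c_local_buffer)
        {
            run_write(i_kc); // local tile -> global, op fused into the copy
        }
        else
        {
            // only write for the last K step: writing earlier would apply the
            // output op to partially accumulated results
            if(i_kc + k_per_block >= GemmK)
                run_write(i_kc);
        }
    }
    return 0;
}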
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp

@@ -8,7 +8,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "tensor_space_filling_curve.hpp"
 #include "dynamic_buffer_cpu.hpp"
-#include <immintrin.h>
+#include "element_wise_operation_cpu.hpp"
 #include "convolution_forward_specialization_cpu.hpp"
 #include <immintrin.h>

@@ -17,7 +17,8 @@ namespace cpu {
 namespace avx2_util {

-inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
+template <typename ElementwiseOp>
+void memcpy32_avx2(void* dst, const void* src, const ck::index_t n, const ElementwiseOp& element_op)
 {
     // 16-8-4-2-1 pattern
     ck::index_t i_n = n;

@@ -25,33 +26,33 @@ inline void memcpy32_avx2(void* dst, const void* src, const ck::index_t n)
     const float* p_src = reinterpret_cast<const float*>(src);

     while(i_n >= 16)
     {
-        _mm256_storeu_ps(p_dst + 0, _mm256_loadu_ps(p_src + 0));
-        _mm256_storeu_ps(p_dst + 8, _mm256_loadu_ps(p_src + 8));
+        _mm256_storeu_ps(p_dst + 0, element_op.Apply(_mm256_loadu_ps(p_src + 0)));
+        _mm256_storeu_ps(p_dst + 8, element_op.Apply(_mm256_loadu_ps(p_src + 8)));
         p_dst += 16;
         p_src += 16;
         i_n -= 16;
     }
     if(i_n & 8)
     {
-        _mm256_storeu_ps(p_dst, _mm256_loadu_ps(p_src));
+        _mm256_storeu_ps(p_dst, element_op.Apply(_mm256_loadu_ps(p_src)));
         p_dst += 8;
         p_src += 8;
     }
     if(i_n & 4)
     {
-        _mm_storeu_ps(p_dst, _mm_loadu_ps(p_src));
+        _mm_storeu_ps(p_dst, element_op.Apply(_mm_loadu_ps(p_src)));
         p_dst += 4;
         p_src += 4;
     }
     if(i_n & 2)
     {
-        _mm_storeu_si64(p_dst, _mm_loadu_si64(p_src));
+        _mm_storeu_si64(p_dst, element_op.Apply(_mm_loadu_si64(p_src)));
         p_dst += 2;
         p_src += 2;
     }
     if(i_n & 1)
     {
-        *p_dst = *p_src;
+        *p_dst = element_op.Apply(*p_src);
     }
 }

@@ -90,8 +91,12 @@ inline void memset32_avx2(void* dst, const int32_t value, const ck::index_t n)
     }
 }

-inline void
-transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_t stride_src)
+template <typename ElementwiseOp>
+void transpose8x8_avx2(void* dst,
+                       ck::index_t stride_dst,
+                       const void* src,
+                       ck::index_t stride_src,
+                       const ElementwiseOp& element_op)
 {
     // TODO: use vinsertf128 for better port usage. vpermf128 is slow
     __m256 r0, r1, r2, r3, r4, r5, r6, r7;

@@ -100,14 +105,14 @@ transpose8x8_avx2(void* dst, ck::index_t stride_dst, const void* src, ck::index_
     float* p_dst       = reinterpret_cast<float*>(dst);
     const float* p_src = reinterpret_cast<const float*>(src);

-    r0 = _mm256_loadu_ps(p_src + 0 * stride_src);
-    r1 = _mm256_loadu_ps(p_src + 1 * stride_src);
-    r2 = _mm256_loadu_ps(p_src + 2 * stride_src);
-    r3 = _mm256_loadu_ps(p_src + 3 * stride_src);
-    r4 = _mm256_loadu_ps(p_src + 4 * stride_src);
-    r5 = _mm256_loadu_ps(p_src + 5 * stride_src);
-    r6 = _mm256_loadu_ps(p_src + 6 * stride_src);
-    r7 = _mm256_loadu_ps(p_src + 7 * stride_src);
+    r0 = element_op.Apply(_mm256_loadu_ps(p_src + 0 * stride_src));
+    r1 = element_op.Apply(_mm256_loadu_ps(p_src + 1 * stride_src));
+    r2 = element_op.Apply(_mm256_loadu_ps(p_src + 2 * stride_src));
+    r3 = element_op.Apply(_mm256_loadu_ps(p_src + 3 * stride_src));
+    r4 = element_op.Apply(_mm256_loadu_ps(p_src + 4 * stride_src));
+    r5 = element_op.Apply(_mm256_loadu_ps(p_src + 5 * stride_src));
+    r6 = element_op.Apply(_mm256_loadu_ps(p_src + 6 * stride_src));
+    r7 = element_op.Apply(_mm256_loadu_ps(p_src + 7 * stride_src));

     t0 = _mm256_unpacklo_ps(r0, r1);
     t1 = _mm256_unpackhi_ps(r0, r1);

@@ -354,11 +359,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
     void SetDstSliceOrigin(const DstDesc&, const Index&) {}

-    template <typename SrcBuffer, typename DstBuffer>
-    void Run(const SrcDesc& src_desc,
-             const SrcBuffer& src_buf,
-             const DstDesc& dst_desc,
-             DstBuffer& dst_buf)
+    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
+    void RunRead(const SrcDesc& src_desc,
+                 const SrcBuffer& src_buf,
+                 const DstDesc& dst_desc,
+                 DstBuffer& dst_buf,
+                 const SliceLengths& slice_length)
     {
         if constexpr(BypassTransfer)
         {

@@ -385,14 +391,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 // standard 8-4-2-1 pattern
                 while(i_m_itr >= 8)
                 {
-                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 4 * k_per_block, p_src + 4 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 5 * k_per_block, p_src + 5 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 6 * k_per_block, p_src + 6 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 7 * k_per_block, p_src + 7 * C, k_per_block, element_op_);
                     i_m_itr -= 8;
                     p_dst += 8 * k_per_block;

@@ -400,10 +414,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 }
                 if(i_m_itr & 4)
                 {
-                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 2 * k_per_block, p_src + 2 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 3 * k_per_block, p_src + 3 * C, k_per_block, element_op_);
                     p_dst += 4 * k_per_block;
                     p_src += 4 * C;

@@ -411,8 +429,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 if(i_m_itr & 2)
                 {
-                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
-                    avx2_util::memcpy32_avx2(p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 1 * k_per_block, p_src + 1 * C, k_per_block, element_op_);
                     p_dst += 2 * k_per_block;
                     p_src += 2 * C;

@@ -420,7 +440,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 if(i_m_itr & 1)
                 {
-                    avx2_util::memcpy32_avx2(p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block);
+                    avx2_util::memcpy32_avx2(
+                        p_dst + 0 * k_per_block, p_src + 0 * C, k_per_block, element_op_);
                 }
             }
             else if constexpr(ConvForwardSpecialization ==

@@ -431,7 +452,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 ck::index_t i_ho_itr = i_ho;
                 while(i_m_itr > 0)
                 {
-                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
+                    avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
                     p_dst += k_per_block;
                     i_wo_itr++;
                     p_src += input_offset_acc_wi;

@@ -468,7 +489,7 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                 {
                     if((*reinterpret_cast<uint32_t*>(&i_hi_itr) < Hi) &&
                        (*reinterpret_cast<uint32_t*>(&i_wi_itr) < Wi))
-                        avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block);
+                        avx2_util::memcpy32_avx2(p_dst, p_src, k_per_block, element_op_);
                     else
                         avx2_util::memset32_avx2(p_dst, 0, k_per_block);

@@ -523,7 +544,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC
                         if((*reinterpret_cast<uint32_t*>(&i_hi_itr_k) < Hi) &&
                            (*reinterpret_cast<uint32_t*>(&i_wi_itr_k) < Wi))
-                            avx2_util::memcpy32_avx2(p_dst_k, p_src_k, current_k_block);
+                            avx2_util::memcpy32_avx2(
+                                p_dst_k, p_src_k, current_k_block, element_op_);
                         else
                             avx2_util::memset32_avx2(p_dst_k, 0, current_k_block);

@@ -730,8 +752,12 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
     void SetDstSliceOrigin(const DstDesc&, const Index&) {}

-    template <typename SrcBuffer, typename DstBuffer>
-    void Run(const SrcDesc&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf)
+    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
+    void RunRead(const SrcDesc&,
+                 const SrcBuffer& src_buf,
+                 const DstDesc& dst_desc,
+                 DstBuffer& dst_buf,
+                 const SliceLengths& slice_length)
     {
         if constexpr(BypassTransfer)
         {

@@ -766,85 +792,85 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
                 float* p_dst_k = p_dst;
                 while(i_k_itr >= 8)
                 {
-                    avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK);
+                    avx2_util::transpose8x8_avx2(p_dst_k, 8, p_src_k, GemmK, element_op_);
                     p_dst_k += 8 * 8;
                     p_src_k += 8;
                     i_k_itr -= 8;
                 }
                 if(i_k_itr & 4)
                 {
-                    p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
-                    p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
-                    p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
-                    p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
-                    p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
-                    p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
-                    p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
-                    p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
-                    p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
-                    p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
-                    p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
-                    p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
-                    p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
-                    p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
-                    p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
-                    p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
-                    p_dst_k[2 * 8 + 0] = p_src_k[0 * GemmK + 2];
-                    p_dst_k[2 * 8 + 1] = p_src_k[1 * GemmK + 2];
-                    p_dst_k[2 * 8 + 2] = p_src_k[2 * GemmK + 2];
-                    p_dst_k[2 * 8 + 3] = p_src_k[3 * GemmK + 2];
-                    p_dst_k[2 * 8 + 4] = p_src_k[4 * GemmK + 2];
-                    p_dst_k[2 * 8 + 5] = p_src_k[5 * GemmK + 2];
-                    p_dst_k[2 * 8 + 6] = p_src_k[6 * GemmK + 2];
-                    p_dst_k[2 * 8 + 7] = p_src_k[7 * GemmK + 2];
-                    p_dst_k[3 * 8 + 0] = p_src_k[0 * GemmK + 3];
-                    p_dst_k[3 * 8 + 1] = p_src_k[1 * GemmK + 3];
-                    p_dst_k[3 * 8 + 2] = p_src_k[2 * GemmK + 3];
-                    p_dst_k[3 * 8 + 3] = p_src_k[3 * GemmK + 3];
-                    p_dst_k[3 * 8 + 4] = p_src_k[4 * GemmK + 3];
-                    p_dst_k[3 * 8 + 5] = p_src_k[5 * GemmK + 3];
-                    p_dst_k[3 * 8 + 6] = p_src_k[6 * GemmK + 3];
-                    p_dst_k[3 * 8 + 7] = p_src_k[7 * GemmK + 3];
+                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
+                    p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);
+                    p_dst_k[2 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 2]);
+                    p_dst_k[2 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 2]);
+                    p_dst_k[3 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 3]);
+                    p_dst_k[3 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 3]);
                     p_dst_k += 4 * 8;
                     p_src_k += 4;
                 }
                 if(i_k_itr & 2)
                 {
-                    p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
-                    p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
-                    p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
-                    p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
-                    p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
-                    p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
-                    p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
-                    p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
-                    p_dst_k[1 * 8 + 0] = p_src_k[0 * GemmK + 1];
-                    p_dst_k[1 * 8 + 1] = p_src_k[1 * GemmK + 1];
-                    p_dst_k[1 * 8 + 2] = p_src_k[2 * GemmK + 1];
-                    p_dst_k[1 * 8 + 3] = p_src_k[3 * GemmK + 1];
-                    p_dst_k[1 * 8 + 4] = p_src_k[4 * GemmK + 1];
-                    p_dst_k[1 * 8 + 5] = p_src_k[5 * GemmK + 1];
-                    p_dst_k[1 * 8 + 6] = p_src_k[6 * GemmK + 1];
-                    p_dst_k[1 * 8 + 7] = p_src_k[7 * GemmK + 1];
+                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
+                    p_dst_k[1 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 1]);
+                    p_dst_k[1 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 1]);
                     p_dst_k += 2 * 8;
                     p_src_k += 2;
                 }
                 if(i_k_itr & 1)
                 {
-                    p_dst_k[0 * 8 + 0] = p_src_k[0 * GemmK + 0];
-                    p_dst_k[0 * 8 + 1] = p_src_k[1 * GemmK + 0];
-                    p_dst_k[0 * 8 + 2] = p_src_k[2 * GemmK + 0];
-                    p_dst_k[0 * 8 + 3] = p_src_k[3 * GemmK + 0];
-                    p_dst_k[0 * 8 + 4] = p_src_k[4 * GemmK + 0];
-                    p_dst_k[0 * 8 + 5] = p_src_k[5 * GemmK + 0];
-                    p_dst_k[0 * 8 + 6] = p_src_k[6 * GemmK + 0];
-                    p_dst_k[0 * 8 + 7] = p_src_k[7 * GemmK + 0];
+                    p_dst_k[0 * 8 + 0] = element_op_.Apply(p_src_k[0 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 1] = element_op_.Apply(p_src_k[1 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 2] = element_op_.Apply(p_src_k[2 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 3] = element_op_.Apply(p_src_k[3 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 4] = element_op_.Apply(p_src_k[4 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 5] = element_op_.Apply(p_src_k[5 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 6] = element_op_.Apply(p_src_k[6 * GemmK + 0]);
+                    p_dst_k[0 * 8 + 7] = element_op_.Apply(p_src_k[7 * GemmK + 0]);
                 }
             }
             else

@@ -858,8 +884,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_NHWC
                     {
                         ck::index_t i_current_n_itr = i_n_itr + i_sub_n + i_gemm_n;
                         float v =
                             i_current_n_itr < GemmN
-                                ? p_src_k[i_sub_n * GemmK + i_sub_k] : .0f;
+                                ? element_op_.Apply(p_src_k[i_sub_n * GemmK + i_sub_k])
+                                : .0f;
                         p_dst_k[i_sub_k * 8 + i_sub_n] = v;
                     }

@@ -949,14 +976,101 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
            dst_offset = i_dst_gemm_m * DstGemmN + i_dst_gemm_n;
        }

-    template <typename SrcBuffer, typename DstBuffer>
-    void Run(const SrcDesc& src_desc,
-             SrcBuffer& src_buf,
-             const DstDesc& dst_desc,
-             DstBuffer& dst_buf)
+    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
+    void RunRead(const SrcDesc& src_desc,
+                 SrcBuffer& src_buf,
+                 const DstDesc& dst_desc,
+                 DstBuffer& dst_buf,
+                 const SliceLengths& slice_length)
     {
         if constexpr(BypassTransfer)
         {
             src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
         }
+    }
+
+    template <typename SrcBuffer, typename DstBuffer, typename SliceLengths>
+    void RunWrite(const SrcDesc& src_desc,
+                  SrcBuffer& src_buf,
+                  const DstDesc& dst_desc,
+                  DstBuffer& dst_buf,
+                  const SliceLengths& slice_length)
+    {
+        if constexpr(BypassTransfer)
+        {
+            // src_buf.p_data_ = reinterpret_cast<float*>(dst_buf.p_data_) + src_offset;
+            if constexpr(!std::is_same<ElementwiseOperation,
+                                       ck::tensor_operation::cpu::element_wise::PassThrough>::value)
+            {
+                // if (true) {
+                const ck::index_t m_per_block = slice_length[Number<0>{}];
+                const ck::index_t n_per_block = slice_length[Number<1>{}];
+                const ck::index_t current_n = ck::math::min(DstGemmN - i_dst_gemm_n, n_per_block);
+                float* p_dst = reinterpret_cast<float*>(dst_buf.p_data_) + dst_offset;
+                ck::index_t i_m_itr = m_per_block;
+
+                // printf("xxxx %d, current_n:%d, DstGemmN:%d, n_per_block:%d,
+                // dst_offset:%d\n",__LINE__, current_n,
+                // DstGemmN, n_per_block, dst_offset);fflush(stdout);
+
+                // standard 8-4-2-1 pattern
+                while(i_m_itr >= 8)
+                {
+                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_dst + 4 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_dst + 5 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_dst + 6 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_dst + 7 * DstGemmN, current_n, element_op_);
+                    i_m_itr -= 8;
+                    p_dst += 8 * DstGemmN;
+                }
+                if(i_m_itr & 4)
+                {
+                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_dst + 2 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_dst + 3 * DstGemmN, current_n, element_op_);
+                    p_dst += 4 * DstGemmN;
+                }
+                if(i_m_itr & 2)
+                {
+                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
+                    avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_dst + 1 * DstGemmN, current_n, element_op_);
+                    p_dst += 2 * DstGemmN;
+                }
+                if(i_m_itr & 1)
+                {
+                    avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_dst + 0 * DstGemmN, current_n, element_op_);
+                }
+            }
+        }
         else
         {
             const ck::index_t m_per_block =

@@ -978,14 +1092,22 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
             // standard 8-4-2-1 pattern
             while(i_m_itr >= 8)
             {
-                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 4 * DstGemmN, p_src + 4 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 5 * DstGemmN, p_src + 5 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 6 * DstGemmN, p_src + 6 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 7 * DstGemmN, p_src + 7 * n_per_block, current_n, element_op_);
                 i_m_itr -= 8;
                 p_dst += 8 * DstGemmN;

@@ -994,10 +1116,14 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
             if(i_m_itr & 4)
             {
-                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 2 * DstGemmN, p_src + 2 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 3 * DstGemmN, p_src + 3 * n_per_block, current_n, element_op_);
                 p_dst += 4 * DstGemmN;
                 p_src += 4 * n_per_block;

@@ -1005,8 +1131,10 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
             if(i_m_itr & 2)
             {
-                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
-                avx2_util::memcpy32_avx2(p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 1 * DstGemmN, p_src + 1 * n_per_block, current_n, element_op_);
                 p_dst += 2 * DstGemmN;
                 p_src += 2 * n_per_block;

@@ -1014,7 +1142,8 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_MxN
             if(i_m_itr & 1)
             {
-                avx2_util::memcpy32_avx2(p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n);
+                avx2_util::memcpy32_avx2(
+                    p_dst + 0 * DstGemmN, p_src + 0 * n_per_block, current_n, element_op_);
             }

             // printf("xxxx %d\n",__LINE__);fflush(stdout);
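All of the copy helpers share the same remainder handling: peel the row in chunks of 16/8/4/2/1 floats and run `Apply` on each vector or scalar as it moves, so the fusion adds no extra pass over memory. A compilable stand-alone version of that idea (simplified: the two-element tail is handled scalar-by-scalar rather than with the `_mm_storeu_si64` pair used in the header):

#include <immintrin.h>
#include <cstdio>

struct Relu
{
    float Apply(const float& x) const { return x > 0 ? x : 0; }
    __m128 Apply(const __m128& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }
    __m256 Apply(const __m256& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }
};

// Copy n floats from src to dst, applying element_op on the fly.
template <typename ElementwiseOp>
void memcpy32_fused(float* dst, const float* src, int n, const ElementwiseOp& element_op)
{
    while(n >= 8) // 8-wide main loop
    {
        _mm256_storeu_ps(dst, element_op.Apply(_mm256_loadu_ps(src)));
        dst += 8; src += 8; n -= 8;
    }
    if(n >= 4) // 4-wide tail
    {
        _mm_storeu_ps(dst, element_op.Apply(_mm_loadu_ps(src)));
        dst += 4; src += 4; n -= 4;
    }
    while(n > 0) // scalar tail
    {
        *dst++ = element_op.Apply(*src++);
        --n;
    }
}

int main()
{
    float in[11], out[11];
    for(int i = 0; i < 11; ++i) in[i] = static_cast<float>(i - 5);
    memcpy32_fused(out, in, 11, Relu{});
    for(int i = 0; i < 11; ++i) std::printf("%g ", out[i]);
    std::printf("\n");
    return 0;
}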
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp

@@ -20,6 +20,7 @@ using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; /
 static constexpr bool NonTemporalStore = false;

 using PT   = ck::tensor_operation::cpu::element_wise::PassThrough;
+using Relu = ck::tensor_operation::cpu::element_wise::Relu;

 using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
     ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                   WeiType,

@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
     DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)>;
     // clang-format on

+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, false),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)>;
+    // clang-format on
+
+// use this in single thread, but gemm_n is not multiple of 8
+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
+    // clang-format on
+
+// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
+// time no local c is better...)
+using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
+    // clang-format off
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   48,  24, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, true, true, true),
+    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, true),
+    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)>;
+    // clang-format on
+
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
 {
     ck::tensor_operation::device::add_device_operation_instances(

@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
         instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
 }

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
+}
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
+}
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
+{
+    ck::tensor_operation::device::add_device_operation_instances(
+        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
+}
+
 } // namespace device_conv2d_fwd_avx2_instance
 } // namespace device
 } // namespace cpu
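The three instance sets above map onto the registration logic used by the test program below: when several OpenMP threads are available the `mt_relu` set (local C buffer, to avoid threads sharing cache lines of C) is registered alongside the plain `relu` set; single-threaded, the plain set is used when the output channel count K (GEMM_N) is a multiple of 8, and `local_c_relu` otherwise. A compact sketch of that decision (the enum and function names are illustrative, not CK symbols):

#include <omp.h>
#include <cstdio>

enum class InstanceSet { MultiThreadRelu, Relu, LocalCRelu };

// Mirrors the choice made in test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp; in the test,
// the multi-threaded path also registers the plain Relu set as a fallback.
InstanceSet pick_relu_instance_set(int K /* output channels == GEMM_N */)
{
    if(omp_get_max_threads() > 1)
        return InstanceSet::MultiThreadRelu;
    return (K % 8 == 0) ? InstanceSet::Relu : InstanceSet::LocalCRelu;
}

int main()
{
    std::printf("K=128 -> %d, K=100 -> %d\n",
                static_cast<int>(pick_relu_instance_set(128)),
                static_cast<int>(pick_relu_instance_set(100)));
    return 0;
}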
profiler/include/profile_conv_fwd_cpu_impl.hpp

@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
 } // namespace device_conv2d_fwd_avx2_instance
 } // namespace device
 } // namespace cpu
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp

@@ -12,6 +12,11 @@
 #include <omp.h>

 #define AVX2_DATA_ALIGNMENT 32

+#define TEST_FUSION_PASSTHROUGH 0
+#define TEST_FUSION_RELU 1
+
+#define TEST_FUSION TEST_FUSION_RELU
+
 using F32 = float;
 using F16 = ck::half_t;

@@ -22,6 +27,7 @@ namespace device {
 namespace device_conv2d_fwd_avx2_instance {

 using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
+using Relu        = ck::tensor_operation::cpu::element_wise::Relu;

 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
 void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
     std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
+void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
+    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);
+
 } // namespace device_conv2d_fwd_avx2_instance
 } // namespace device
 } // namespace cpu

@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
 using InElementOp  = ck::tensor_operation::cpu::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
 using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
+#endif

 template <typename T>
 static bool

@@ -296,8 +316,15 @@ int main(int argc, char* argv[])
     }

     using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
+    using Relu        = ck::tensor_operation::cpu::element_wise::Relu;

+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
     using DeviceConvFwdNoOpPtr = ck::tensor_operation::cpu::device::
         DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+    using DeviceConvFwdNoOpPtr =
+        ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
+#endif

     // add device Conv instances
     std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;

@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
        ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
        ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
     {
+#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
         if(omp_get_max_threads() > 1)
         {
             ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::

@@ -322,6 +350,25 @@ int main(int argc, char* argv[])
             ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                 add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
         }
+#endif
+#if TEST_FUSION == TEST_FUSION_RELU
+        if(omp_get_max_threads() > 1)
+        {
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
+            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
+        }
+        else
+        {
+            if(K % 8 == 0)
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
+            else
+                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
+                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
+        }
+#endif
     }

     if(conv_ptrs.size() <= 0)