gaoqiong / composable_kernel / Commits / 090ba885

Commit 090ba885 authored May 15, 2022 by carlushuang

add elementwise fusion support

parent 8ce9fe57

Showing 7 changed files with 842 additions and 510 deletions (+842 -510)
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp  +5 -0
include/ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp  +342 -370
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp  +109 -14
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp  +253 -124
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp  +76 -1
profiler/include/profile_conv_fwd_cpu_impl.hpp  +9 -0
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp  +48 -1
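The fusion added here works by applying the output elementwise operation (e.g. Relu) while the GEMM result is written back to the output tensor, so the activation costs no extra pass over memory. A minimal self-contained sketch of that idea — the scalar Relu functor below is a simplified stand-in for the element_wise::Relu added in this commit, and the buffer names are illustrative, not the commit's code:

#include <cstddef>

// Simplified scalar stand-in for ck::tensor_operation::cpu::element_wise::Relu.
struct Relu
{
    void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
};

// Apply the fused output op while storing the accumulator into the output
// buffer: one pass over memory instead of a GEMM store followed by a
// separate activation pass.
void store_with_fusion(float* dst, const float* acc, std::size_t n)
{
    Relu op;
    for(std::size_t i = 0; i < n; ++i)
        op(dst[i], acc[i]);
}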
include/ck/tensor_operation/cpu/device/device_convnd_fwd_avx2_nhwc_kyxc_nhwk.hpp
...
...
@@ -896,6 +896,11 @@ struct DeviceConvNDFwdAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
            << "_B" << string_local_buffer(UseBLocalBuffer)
            << "_C" << string_local_buffer(UseCLocalBuffer);

        if constexpr(!std::is_same<OutElementwiseOperation,
                                   ck::tensor_operation::cpu::element_wise::PassThrough>::value)
        {
            str << "_" << OutElementwiseOperation::Name();
        }
        // clang-format on

        return str.str();
...
...
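The hunk above appends the output op's Name() to a kernel instance's identification string whenever the op is not PassThrough. A standalone sketch of that naming pattern — the stub types and the base string below are simplified illustrations, not the commit's actual types:

#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>

struct PassThrough { static constexpr const char* Name() { return "PassThrough"; } };
struct Relu        { static constexpr const char* Name() { return "Relu"; } };

// Same pattern as the hunk above: only non-PassThrough output ops are appended.
template <typename OutOp>
std::string instance_name(const std::string& base)
{
    std::ostringstream str;
    str << base;
    if constexpr(!std::is_same<OutOp, PassThrough>::value)
    {
        str << "_" << OutOp::Name();
    }
    return str.str();
}

int main()
{
    std::cout << instance_name<PassThrough>("conv2d_avx2_B0_C1") << "\n"; // conv2d_avx2_B0_C1
    std::cout << instance_name<Relu>("conv2d_avx2_B0_C1") << "\n";        // conv2d_avx2_B0_C1_Relu
}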
include/ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp
#pragma once

#include "data_type_cpu.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {

using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;

struct PassThrough
{
    void operator()(float& y, const float& x) const { y = x; }

    void operator()(float4_t& y, const float4_t& x) const { y = x; }

    void operator()(float8_t& y, const float8_t& x) const { y = x; }
};

struct Add
{
    void operator()(float& y, const float& x0, const float& x1) const { y = x0 + x1; }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = _mm_add_ps(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = _mm256_add_ps(x0, x1);
    }
};

struct AlphaBetaAdd
{
    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

    void operator()(float& y, const float& x0, const float& x1) const
    {
        y = alpha_ * x0 + beta_ * x1;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
                          _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
    }

    float alpha_;
    float beta_;
};

struct AddRelu
{
    void operator()(float& y, const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        y             = a > 0 ? a : 0;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
    }
};
#if 0
struct AddHardswish
{
void operator()(float& y, const float& x0, const float& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
y = c;
}
};
#endif
struct AddReluAdd
{
    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        float a = x0 + x1;
        float b = a > 0 ? a : 0;
        float c = b + x2;
        y       = c;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
    {
        float4_t a = _mm_add_ps(x0, x1);
        float4_t b = _mm_max_ps(a, _mm_setzero_ps());
        y          = _mm_add_ps(b, x2);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
    {
        float8_t a = _mm256_add_ps(x0, x1);
        float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
        y          = _mm256_add_ps(b, x2);
    }
};
#if 0
struct AddHardswishAdd
{
void
operator()(float& y, const float& x0, const float& x1, const float& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
void
operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
{
float a = x0 + x1;
float b = a + float{3};
float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
float d = c + x2;
y = d;
}
};
#endif
#if 0
struct RequantReluRequant
{
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu)
{
}
void operator()(int8_t& y, const int& x) const
{
float gemm_requant = scaleGemm_ * static_cast<float>(x);
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<int8_t>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
// for reference_gemm
void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<float>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
float scaleGemm_;
float scaleRelu_;
};
#endif
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;

template <>
struct UnaryIdentic<float, float, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x; };
};

template <>
struct UnaryIdentic<float, float, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

template <>
struct UnarySquare<float, float, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x * x; };
};

template <>
struct UnarySquare<float, float, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const
    {
        y = x * x / type_convert<float>(divider_);
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float4_t, float4_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};

template <>
struct UnarySquare<float4_t, float4_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float8_t, float8_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};

template <>
struct UnarySquare<float8_t, float8_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X>
struct UnaryAbs;

template <>
struct UnaryAbs<float, float>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<float4_t, float4_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        __m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
        y           = _mm_and_ps(Mask, x);
    };
};

template <>
struct UnaryAbs<float8_t, float8_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        __m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
        y           = _mm256_and_ps(Mask, x);
    };
};

template <typename Y, typename X>
struct UnarySqrt;

template <>
struct UnarySqrt<float, float>
{
    void operator()(float& y, const float& x) const { y = sqrtf(x); };
};

template <>
struct UnarySqrt<float4_t, float4_t>
{
    void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};

template <>
struct UnarySqrt<float8_t, float8_t>
{
    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};

} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
#pragma once

#include "data_type_cpu.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace element_wise {

using float8_t = ck::cpu::float8_t;
using float4_t = ck::cpu::float4_t;

struct PassThrough
{
    void operator()(float& y, const float& x) const { y = Apply(x); }

    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }

    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return x; }

    float4_t Apply(const float4_t& x) const { return x; }

    float8_t Apply(const float8_t& x) const { return x; }

    static constexpr char* Name() { return "PassThrough"; }
};

struct Add
{
    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const { return x0 + x1; }

    float4_t Apply(const float4_t& x0, const float4_t& x1) const { return _mm_add_ps(x0, x1); }

    float8_t Apply(const float8_t& x0, const float8_t& x1) const { return _mm256_add_ps(x0, x1); }

    static constexpr char* Name() { return "Add"; }
};

struct Relu
{
    void operator()(float& y, const float& x) const { y = Apply(x); }

    void operator()(float4_t& y, const float4_t& x) const { y = Apply(x); }

    void operator()(float8_t& y, const float8_t& x) const { y = Apply(x); }

    float Apply(const float& x) const { return x > 0 ? x : 0; }

    float4_t Apply(const float4_t& x) const { return _mm_max_ps(x, _mm_setzero_ps()); }

    float8_t Apply(const float8_t& x) const { return _mm256_max_ps(x, _mm256_setzero_ps()); }

    static constexpr char* Name() { return "Relu"; }
};

struct AlphaBetaAdd
{
    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const { return alpha_ * x0 + beta_ * x1; }

    float4_t Apply(const float4_t& x0, const float4_t& x1) const
    {
        return _mm_add_ps(_mm_mul_ps(x0, _mm_set1_ps(alpha_)), _mm_mul_ps(x1, _mm_set1_ps(beta_)));
    }

    float8_t Apply(const float8_t& x0, const float8_t& x1) const
    {
        return _mm256_add_ps(_mm256_mul_ps(x0, _mm256_set1_ps(alpha_)),
                             _mm256_mul_ps(x1, _mm256_set1_ps(beta_)));
    }

    static constexpr char* Name() { return "AlphaBetaAdd"; }

    float alpha_;
    float beta_;
};

struct AddRelu
{
    void operator()(float& y, const float& x0, const float& x1) const { y = Apply(x0, x1); }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1) const
    {
        y = Apply(x0, x1);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1) const
    {
        y = Apply(x0, x1);
    }

    float Apply(const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        return a > 0 ? a : 0;
    }

    float4_t Apply(const float4_t& x0, const float4_t& x1) const
    {
        return _mm_max_ps(_mm_add_ps(x0, x1), _mm_setzero_ps());
    }

    float8_t Apply(const float8_t& x0, const float8_t& x1) const
    {
        return _mm256_max_ps(_mm256_add_ps(x0, x1), _mm256_setzero_ps());
    }

    static constexpr char* Name() { return "AddRelu"; }
};

struct AddReluAdd
{
    void operator()(float& y, const float& x0, const float& x1, const float& x2) const
    {
        float a = x0 + x1;
        float b = a > 0 ? a : 0;
        float c = b + x2;
        y       = c;
    }

    void operator()(float4_t& y, const float4_t& x0, const float4_t& x1, const float4_t& x2) const
    {
        float4_t a = _mm_add_ps(x0, x1);
        float4_t b = _mm_max_ps(a, _mm_setzero_ps());
        y          = _mm_add_ps(b, x2);
    }

    void operator()(float8_t& y, const float8_t& x0, const float8_t& x1, const float8_t& x2) const
    {
        float8_t a = _mm256_add_ps(x0, x1);
        float8_t b = _mm256_max_ps(a, _mm256_setzero_ps());
        y          = _mm256_add_ps(b, x2);
    }

    static constexpr char* Name() { return "AddReluAdd"; }
};
// Unary operators are usually called element-wisely before/after the reduction is executed on the
// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2
template <typename Y, typename X, bool HasDividing = false>
struct UnaryIdentic;

template <>
struct UnaryIdentic<float, float, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x; };
};

template <>
struct UnaryIdentic<float, float, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const { y = x / type_convert<float>(divider_); };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float4_t, float4_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float4_t, float4_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(x, _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnaryIdentic<float8_t, float8_t, false>
{
    UnaryIdentic(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = x; };
};

template <>
struct UnaryIdentic<float8_t, float8_t, true>
{
    UnaryIdentic(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(x, _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X, bool HasDividing = false>
struct UnarySquare;

template <>
struct UnarySquare<float, float, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = x * x; };
};

template <>
struct UnarySquare<float, float, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float& y, const float& x) const
    {
        y = x * x / type_convert<float>(divider_);
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float4_t, float4_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const { y = _mm_mul_ps(x, x); };
};

template <>
struct UnarySquare<float4_t, float4_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        y = _mm_div_ps(_mm_mul_ps(x, x), _mm_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <>
struct UnarySquare<float8_t, float8_t, false>
{
    UnarySquare(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_mul_ps(x, x); };
};

template <>
struct UnarySquare<float8_t, float8_t, true>
{
    UnarySquare(const int32_t divider = 1) { divider_ = divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        y = _mm256_div_ps(_mm256_mul_ps(x, x), _mm256_set1_ps(static_cast<float>(divider_)));
    };

    int32_t divider_ = 1;
};

template <typename Y, typename X>
struct UnaryAbs;

template <>
struct UnaryAbs<float, float>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float& y, const float& x) const { y = abs(x); };
};

template <>
struct UnaryAbs<float4_t, float4_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float4_t& y, const float4_t& x) const
    {
        __m128 Mask = _mm_castsi128_ps(_mm_set1_epi32(~0x80000000));
        y           = _mm_and_ps(Mask, x);
    };
};

template <>
struct UnaryAbs<float8_t, float8_t>
{
    UnaryAbs(const int32_t divider = 1) { (void)divider; };

    void operator()(float8_t& y, const float8_t& x) const
    {
        __m256 Mask = _mm256_castsi256_ps(_mm256_set1_epi32(~0x80000000));
        y           = _mm256_and_ps(Mask, x);
    };
};

template <typename Y, typename X>
struct UnarySqrt;

template <>
struct UnarySqrt<float, float>
{
    void operator()(float& y, const float& x) const { y = sqrtf(x); };
};

template <>
struct UnarySqrt<float4_t, float4_t>
{
    void operator()(float4_t& y, const float4_t& x) const { y = _mm_sqrt_ps(x); };
};

template <>
struct UnarySqrt<float8_t, float8_t>
{
    void operator()(float8_t& y, const float8_t& x) const { y = _mm256_sqrt_ps(x); };
};

} // namespace element_wise
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
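As the comment above notes, the HasDividing variants of UnaryIdentic and UnarySquare let reduction kernels fold the final division of AVG- and NRM2-style reductions into the elementwise pass. A small sketch of the intended use — the simplified functor and the reduction loop are illustrative, not the library's kernel code:

#include <cstdint>

// Simplified stand-in for UnaryIdentic<float, float, true>: the
// post-reduction elementwise op divides the accumulated sum by the
// element count.
struct UnaryIdenticDiv
{
    UnaryIdenticDiv(const int32_t divider = 1) : divider_(divider) {}
    void operator()(float& y, const float& x) const { y = x / static_cast<float>(divider_); }
    int32_t divider_ = 1;
};

// AVG = SUM followed by the dividing unary op; with HasDividing = false the
// same machinery computes a plain SUM.
float average(const float* data, int32_t n)
{
    float sum = 0.f;
    for(int32_t i = 0; i < n; ++i)
        sum += data[i];

    UnaryIdenticDiv op(n);
    float avg = 0.f;
    op(avg, sum);
    return avg;
}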
include/ck/tensor_operation/cpu/grid/gridwise_gemm_avx2.hpp
...
...
@@ -128,6 +128,51 @@ struct GridwiseGemmAvx2_MxN
        return make_naive_tensor_descriptor_packed(make_tuple(m_per_blk, n_per_blk));
    }

    static auto GetAMultiIndex(const ck::index_t m_per_blk, const ck::index_t k_per_blk)
    {
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixALayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // A : M, K
            return ck::make_multi_index(m_per_blk, k_per_blk);
        }
        else
        {
            // A : K, M
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(m_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixAMinVectorSize));
        }
    }

    static auto GetBMultiIndex(const ck::index_t k_per_blk, const ck::index_t n_per_blk)
    {
        // n_per_blk should be 8x
        if constexpr(std::is_same<typename ThreadwiseGemm_Dispatch::MatrixBLayout,
                                  ck::tensor_layout::gemm::RowMajor>::value)
        {
            // B : K, N
            return ck::make_multi_index(
                k_per_blk,
                math::integer_least_multiple(n_per_blk,
                                             ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
        }
        else
        {
            // B : N/8, K, N8
            return ck::make_multi_index(
                math::integer_divide_ceil(n_per_blk, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize),
                k_per_blk,
                ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        }
    }

    static auto GetCMultiIndex(const ck::index_t m_per_blk, const ck::index_t n_per_blk)
    {
        return ck::make_multi_index(m_per_blk, n_per_blk);
    }

    static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                        const BGridDesc& b_grid_desc,
                                        const CGridDesc& c_grid_desc)
...
...
@@ -300,14 +345,18 @@ struct GridwiseGemmAvx2_MxN
                UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size) : c_grid_desc;

            if constexpr(UseCLocalBuffer)
            {
                c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
                // c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                // ck::make_multi_index(i_mc, i_nc));
            }
            else
            {
                c_threadwise_copy.SetSrcSliceOrigin(c_block_desc, ck::make_multi_index(i_mc, i_nc));
                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                c_threadwise_copy.RunRead(c_block_desc,
                                          c_block_buf,
                                          c_grid_desc,
                                          c_grid_buf,
                                          GetCMultiIndex(mc_size, nc_size));
            }

            for(ck::index_t i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
...
...
@@ -317,8 +366,16 @@ struct GridwiseGemmAvx2_MxN
                auto a_block_desc = GetABlockDescriptor(mc_size, kc_size);
                auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
                a_threadwise_copy.RunRead(a_grid_desc,
                                          a_grid_buf,
                                          a_block_desc,
                                          a_block_buf,
                                          GetAMultiIndex(mc_size, kc_size));
                b_threadwise_copy.RunRead(b_grid_desc,
                                          b_grid_buf,
                                          b_block_desc,
                                          b_block_buf,
                                          GetBMultiIndex(kc_size, nc_size));

                blockwise_gemm.Run(a_block_desc,
                                   a_block_buf,
...
...
@@ -338,8 +395,14 @@ struct GridwiseGemmAvx2_MxN
                }
            }

            if constexpr(UseCLocalBuffer)
                c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
            // if constexpr(UseCLocalBuffer)
            c_threadwise_copy.SetDstSliceOrigin(c_grid_desc, ck::make_multi_index(i_mc, i_nc));
            c_threadwise_copy.RunWrite(c_block_desc,
                                       c_block_buf,
                                       c_grid_desc,
                                       c_grid_buf,
                                       GetCMultiIndex(mc_size, nc_size));
        }
    }
}
...
...
@@ -415,7 +478,11 @@ struct GridwiseGemmAvx2_MxN
            ck::index_t kc_size = ck::math::min(GemmK - i_kc, k_per_block);
            auto a_block_desc   = GetABlockDescriptor(mc_size, kc_size);

            a_threadwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
            a_threadwise_copy.RunRead(a_grid_desc,
                                      a_grid_buf,
                                      a_block_desc,
                                      a_block_buf,
                                      GetAMultiIndex(mc_size, kc_size));

            b_threadwise_copy.SetSrcSliceOrigin(b_grid_desc, ck::make_multi_index(0, i_kc, 0));
...
...
@@ -429,8 +496,11 @@ struct GridwiseGemmAvx2_MxN
                    nc_size, ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
                auto b_block_desc = GetBBlockDescriptor(kc_size, nc_size);

                b_threadwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
                b_threadwise_copy.RunRead(b_grid_desc,
                                          b_grid_buf,
                                          b_block_desc,
                                          b_block_buf,
                                          GetBMultiIndex(kc_size, nc_size));

                auto c_block_desc = UseCLocalBuffer ? GetCBlockDescriptor(mc_size, nc_size)
...
...
@@ -440,8 +510,11 @@ struct GridwiseGemmAvx2_MxN
                {
                    c_threadwise_copy.SetSrcSliceOrigin(c_block_desc,
                                                        ck::make_multi_index(i_mc, i_nc));
                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                    c_threadwise_copy.RunRead(c_block_desc,
                                              c_block_buf,
                                              c_grid_desc,
                                              c_grid_buf,
                                              GetCMultiIndex(mc_size, nc_size));
                }

                blockwise_gemm.Run(a_block_desc,
...
...
@@ -456,14 +529,36 @@ struct GridwiseGemmAvx2_MxN
                                   i_kc != 0);

                if((i_nc + n_per_block) < GemmN)
                {
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_move_k_step);
                }

                if constexpr(UseCLocalBuffer)
                {
                    c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                        ck::make_multi_index(i_mc, i_nc));
                    c_threadwise_copy.Run(c_block_desc, c_block_buf, c_grid_desc, c_grid_buf);
                    c_threadwise_copy.RunWrite(c_block_desc,
                                               c_block_buf,
                                               c_grid_desc,
                                               c_grid_buf,
                                               GetCMultiIndex(mc_size, nc_size));
                }
                else
                {
                    // only write for last K, since the RunWrite here is just doing
                    // elementwise op from global to global
                    if((i_kc + k_per_block) >= GemmK)
                    {
                        c_threadwise_copy.SetDstSliceOrigin(c_grid_desc,
                                                            ck::make_multi_index(i_mc, i_nc));
                        c_threadwise_copy.RunWrite(c_block_desc,
                                                   c_block_buf,
                                                   c_grid_desc,
                                                   c_grid_buf,
                                                   GetCMultiIndex(mc_size, nc_size));
                    }
                }
            }
...
...
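The hunks above split the old threadwise-copy Run into RunRead and RunWrite so the output elementwise op is applied exactly once per C tile: on the local-buffer path when the buffer is flushed to global memory, and on the direct-write path only at the last K step, once C holds the complete sum (partial sums must not be run through a nonlinear op such as Relu). A condensed, self-contained sketch of that policy — the Relu functor and plain vectors stand in for the real descriptors and buffers:

#include <vector>

struct Relu
{
    void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
};

// Stand-in for c_threadwise_copy.RunWrite: apply the output op during write-back.
static void run_write(std::vector<float>& c, const Relu& op)
{
    for(float& v : c)
        op(v, v);
}

void gemm_k_loop(std::vector<float>& c, int GemmK, int k_per_block, bool use_c_local_buffer)
{
    Relu op;
    for(int i_kc = 0; i_kc < GemmK; i_kc += k_per_block)
    {
        // ... accumulate a k_per_block slice of the K dimension into c ...

        // "only write for last K": without a local buffer, c is updated in
        // place, so the op may only be applied once the sum is complete.
        if(!use_c_local_buffer && (i_kc + k_per_block) >= GemmK)
            run_write(c, op);
    }

    // With a local buffer, the op is applied once while flushing it out.
    if(use_c_local_buffer)
        run_write(c, op);
}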
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
This diff is collapsed.
library/src/tensor_operation_instance/cpu/conv2d_fwd/device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_instance.cpp
...
...
@@ -19,7 +19,8 @@ using InLayout = ck::tensor_layout::gemm::RowMajor; /
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXC

static constexpr bool NonTemporalStore = false;

using PT   = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu = ck::tensor_operation::cpu::element_wise::Relu;

using ThreadwiseGemmAvx2_MxN_4x24_Dispatch =
    ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InType,
                                                  WeiType,
...
...
@@ -110,6 +111,59 @@ using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances = std::tuple<
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 1024, 416, 128, 6, 16, true, true, true)
    >;
// clang-format on

using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, false),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, false)
    >;
// clang-format on

// use this in single thread, but gemm_n is not multiple of 8
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128,  64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)
    >;
// clang-format on

// use this in multi thread environment (need local C buffer to avoid cache coherence, although some
// time no local c is better...)
using device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances = std::tuple<
    // clang-format off
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   48,  24, 128, 4, 24, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  16, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   72,  32, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  32, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,   96,  64, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  32, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  120,  64, 128, 6, 16, true, true, true),
    // DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, PT, 256, 128, 64, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  256, 128, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  128, 256, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 240, 128, 4, 24, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  512, 256, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  768, 320, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu,  896, 352, 128, 6, 16, true, true, true),
    DEVICE_CONV2D_FWD_AVX2_NHWC_KYXC_NHWK_F32(PT, PT, Relu, 1024, 416, 128, 6, 16, true, true, true)
    >;
// clang-format on

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PT, PT, PT>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
...
...
@@ -130,6 +184,27 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_relu_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_local_c_relu_instances{});
}

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PT, PT, Relu>>& instances)
{
    ck::tensor_operation::device::add_device_operation_instances(
        instances, device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_f32_mt_relu_instances{});
}

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...
...
profiler/include/profile_conv_fwd_cpu_impl.hpp
...
...
@@ -24,6 +24,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...
...
test/convnd_fwd_cpu/conv2d_fwd_cpu.cpp
...
...
@@ -12,6 +12,11 @@
#include <omp.h>
#define AVX2_DATA_ALIGNMENT 32
#define TEST_FUSION_PASSTHROUGH 0
#define TEST_FUSION_RELU 1
#define TEST_FUSION TEST_FUSION_RELU
using F32 = float;
using F16 = ck::half_t;
...
...
@@ -22,6 +27,7 @@ namespace device {
namespace device_conv2d_fwd_avx2_instance {

using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
using Relu        = ck::tensor_operation::cpu::element_wise::Relu;

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
...
...
@@ -32,6 +38,15 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(
void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(
    std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, Relu>>& instances);

} // namespace device_conv2d_fwd_avx2_instance
} // namespace device
} // namespace cpu
...
...
@@ -40,7 +55,12 @@ void add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt(
using InElementOp  = ck::tensor_operation::cpu::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
using OutElementOp = ck::tensor_operation::cpu::element_wise::PassThrough;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
using OutElementOp = ck::tensor_operation::cpu::element_wise::Relu;
#endif

template <typename T>
static bool
...
...
@@ -295,9 +315,16 @@ int main(int argc, char* argv[])
        ref_invoker.Run(ref_argument);
    }

    using PassThrough = ck::tensor_operation::cpu::element_wise::PassThrough;
    using Relu        = ck::tensor_operation::cpu::element_wise::Relu;

#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
    using DeviceConvFwdNoOpPtr =
        ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>;
#endif
#if TEST_FUSION == TEST_FUSION_RELU
    using DeviceConvFwdNoOpPtr =
        ck::tensor_operation::cpu::device::DeviceConvFwdPtr<PassThrough, PassThrough, Relu>;
#endif

    // add device Conv instances
    std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
...
...
@@ -306,6 +333,7 @@ int main(int argc, char* argv[])
       ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
       ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
    {
#if TEST_FUSION == TEST_FUSION_PASSTHROUGH
        if(omp_get_max_threads() > 1)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
...
...
@@ -322,6 +350,25 @@ int main(int argc, char* argv[])
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c(conv_ptrs);
        }
#endif
#if TEST_FUSION == TEST_FUSION_RELU
        if(omp_get_max_threads() > 1)
        {
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_mt_relu(conv_ptrs);
            ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
        }
        else
        {
            if(K % 8 == 0)
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_relu(conv_ptrs);
            else
                ck::tensor_operation::cpu::device::device_conv2d_fwd_avx2_instance::
                    add_device_conv2d_fwd_avx2_nhwc_kyxc_nhwk_local_c_relu(conv_ptrs);
        }
#endif
    }

    if(conv_ptrs.size() <= 0)
...
...