Project: gaoqiong / composable_kernel — Commits

Commit dd6a8de4, authored Apr 06, 2022 by Jehandad Khan

    Merge branch 'develop' into jd/dev_pkg

Parents: 0aa899aa, abf4bdb9
Changes: 470 files in total; showing 20 changed files with 425 additions and 151 deletions (+425 −151).
Changed files shown on this page:

library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +2 −2
library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +2 −2
library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +2 −2
library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp  +1 −1
library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp  +1 −1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp  +3 −3
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp  +194 −59
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp  +61 −6
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp  +13 −0
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp  +42 −40
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp  +60 −0
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp  +25 −19
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp  +12 −9

library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_mn.hpp

@@ -398,7 +398,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),

library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_kn_nm.hpp

@@ -230,7 +230,7 @@ void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),

library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_mn.hpp

@@ -499,7 +499,7 @@ void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),

library/include/ck/library/obselete_driver_offline/device_gemm_xdlops_mk_nk_nm.hpp

@@ -286,7 +286,7 @@ void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),

library/include/ck/library/obselete_driver_offline/driver_contraction_dlops_v1r2.hpp

@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_GK0_GM0_GM1_GK1,
           typename BGridDesc_GK0_GN0_GN1_GK1,
           typename CGridDesc_GM0_GM1_GN0_GN1,

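The change in these driver headers is uniform: the `_t` suffix is dropped from the scoped-enum type names (`InMemoryDataOperationEnum_t`, and further below `ActivTypeEnum_t`), while the enumerators and call sites otherwise stay the same. Below is a minimal, self-contained sketch of a scoped enum used as a non-type template parameter; the enumerator list is an assumption for illustration only and is not the library's actual definition.

    // Illustrative sketch only: a scoped enum as a non-type template parameter.
    // The enumerators below are assumed for the example; the real definition
    // lives in the library's enum headers.
    #include <cstdio>

    namespace ck {
    enum struct InMemoryDataOperationEnum { Set, AtomicAdd }; // assumed subset
    } // namespace ck

    template <ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    struct GridwiseOpSketch
    {
        static void Describe()
        {
            std::printf("data op = %d\n", static_cast<int>(CGlobalMemoryDataOperation));
        }
    };

    int main()
    {
        // After the rename, call sites simply drop the "_t" suffix.
        GridwiseOpSketch<ck::InMemoryDataOperationEnum::Set>::Describe();
        return 0;
    }
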
library/include/ck/library/obselete_driver_offline/driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp

@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
 {
     template <typename... Wei,

@@ -294,7 +294,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),

library/include/ck/library/obselete_driver_offline/driver_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp

@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
 {
     template <typename... Wei,

@@ -260,7 +260,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),

library/include/ck/library/obselete_driver_offline/driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp

@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
 {
     template <typename... Wei,

@@ -305,7 +305,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),

library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r2.hpp

@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AKMGridDesc,
           typename BKNGridDesc,
           typename CMNGridDesc,

library/include/ck/library/obselete_driver_offline/driver_gemm_dlops_v1r3.hpp

@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
           typename CMNGridDesc,

library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r3.hpp

@@ -11,7 +11,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K,
           typename CMNGridDesc,

library/include/ck/library/obselete_driver_offline/driver_gemm_xdlops_v2r4.hpp

@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CMNGridDesc,

library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp

@@ -17,7 +17,7 @@ template <typename InDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation>
-struct ReferenceConvWrw : public device::BaseOperator
+struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument

@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator
     // Invoker
     struct Invoker : public device::BaseInvoker
     {
-        using Argument = ReferenceConvWrw::Argument;
+        using Argument = ReferenceConvBwdWeight::Argument;

         float Run(const Argument& arg)
         {

@@ -163,7 +163,7 @@ struct ReferenceConvWrw : public device::BaseOperator
         auto str = std::stringstream();

         // clang-format off
-        str << "ReferenceConvFwd"
+        str << "ReferenceConvBwdWeight"
             << std::endl;
         // clang-format on

library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp

@@ -14,17 +14,20 @@ namespace host {
 template <typename InDataType,
           typename WeiDataType,
           typename OutDataType,
+          typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
-          typename OutElementwiseOperation>
+          typename OutElementwiseOperation,
+          ck::index_t NumDimSpatial = 2,
+          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdData : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument
     {
-        Argument(Tensor<InDataType>& in_n_c_hi_wi,
-                 const Tensor<WeiDataType>& wei_k_c_y_x,
-                 const Tensor<OutDataType>& out_n_k_ho_wo,
+        Argument(Tensor<InDataType>& input,
+                 const Tensor<WeiDataType>& weight,
+                 const Tensor<OutDataType>& output,
                  std::vector<ck::index_t> conv_filter_strides,
                  std::vector<ck::index_t> conv_filter_dilations,
                  std::vector<ck::index_t> input_left_pads,

@@ -32,9 +35,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op)
-            : in_n_c_hi_wi_{in_n_c_hi_wi},
-              wei_k_c_y_x_{wei_k_c_y_x},
-              out_n_k_ho_wo_{out_n_k_ho_wo},
+            : input_{input},
+              weight_{weight},
+              output_{output},
               conv_strides_{conv_filter_strides},
               conv_dilations_{conv_filter_dilations},
               in_left_pads_{input_left_pads},

@@ -45,9 +48,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
         {
         }

-        Tensor<InDataType>& in_n_c_hi_wi_;
-        const Tensor<WeiDataType>& wei_k_c_y_x_;
-        const Tensor<OutDataType>& out_n_k_ho_wo_;
+        Tensor<InDataType>& input_;
+        const Tensor<WeiDataType>& weight_;
+        const Tensor<OutDataType>& output_;
         std::vector<index_t> conv_strides_;
         std::vector<index_t> conv_dilations_;

@@ -66,67 +69,199 @@ struct ReferenceConvBwdData : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
-            std::size_t K  = arg.wei_k_c_y_x_.mDesc.GetLengths()[0];
-            std::size_t Y  = arg.wei_k_c_y_x_.mDesc.GetLengths()[2];
-            std::size_t X  = arg.wei_k_c_y_x_.mDesc.GetLengths()[3];
-            std::size_t Ho = arg.out_n_k_ho_wo_.mDesc.GetLengths()[2];
-            std::size_t Wo = arg.out_n_k_ho_wo_.mDesc.GetLengths()[3];
-
-            float v_acc = 0;
-
-            for(int y = 0; y < Y; ++y)
-            {
-                int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
-                if(h_tmp % arg.conv_strides_[0] == 0)
-                {
-                    int ho = h_tmp / arg.conv_strides_[0];
-                    if(ho >= 0 && ho < Ho)
-                    {
-                        for(int x = 0; x < X; ++x)
-                        {
-                            int w_tmp = wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
-                            if(w_tmp % arg.conv_strides_[1] == 0)
-                            {
-                                int wo = w_tmp / arg.conv_strides_[1];
-                                if(wo >= 0 && wo < Wo)
-                                {
-                                    for(int k = 0; k < K; ++k)
-                                    {
-                                        float v_out = 0;
-                                        float v_wei = 0;
-
-                                        arg.out_element_op_(
-                                            v_out,
-                                            ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
-                                        arg.wei_element_op_(
-                                            v_wei,
-                                            ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
-
-                                        v_acc += v_out * v_wei;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-
-            float v_in;
-            arg.in_element_op_(v_in, v_acc);
-            arg.in_n_c_hi_wi_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
-        };
-
-        make_ParallelTensorFunctor(f_nchw,
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[0],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[1],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[2],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])(
-            std::thread::hardware_concurrency());
-
-        return 0;
+        if constexpr(NumDimSpatial == 1)
+        {
+            auto f_ncw = [&](auto n, auto c, auto wi) {
+                std::size_t K  = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t X  = arg.weight_.mDesc.GetLengths()[2];
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
+
+                AccDataType v_acc = 0;
+
+                for(int x = 0; x < X; ++x)
+                {
+                    int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
+                    if(w_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int wo = w_tmp / arg.conv_strides_[0];
+                        if(wo >= 0 && wo < Wo)
+                        {
+                            for(int k = 0; k < K; ++k)
+                            {
+                                AccDataType v_out = 0;
+                                AccDataType v_wei = 0;
+
+                                arg.out_element_op_(
+                                    v_out, ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
+                                arg.wei_element_op_(
+                                    v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
+
+                                v_acc += v_out * v_wei;
+                            }
+                        }
+                    }
+                }
+
+                float v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_ncw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+        else if constexpr(NumDimSpatial == 2)
+        {
+            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
+                std::size_t K  = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t Y  = arg.weight_.mDesc.GetLengths()[2];
+                std::size_t X  = arg.weight_.mDesc.GetLengths()[3];
+                std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
+
+                AccDataType v_acc = 0;
+
+                for(int y = 0; y < Y; ++y)
+                {
+                    int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
+                    if(h_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int ho = h_tmp / arg.conv_strides_[0];
+                        if(ho >= 0 && ho < Ho)
+                        {
+                            for(int x = 0; x < X; ++x)
+                            {
+                                int w_tmp =
+                                    wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
+                                if(w_tmp % arg.conv_strides_[1] == 0)
+                                {
+                                    int wo = w_tmp / arg.conv_strides_[1];
+                                    if(wo >= 0 && wo < Wo)
+                                    {
+                                        for(int k = 0; k < K; ++k)
+                                        {
+                                            AccDataType v_out = 0;
+                                            AccDataType v_wei = 0;
+
+                                            arg.out_element_op_(
+                                                v_out,
+                                                ck::type_convert<AccDataType>(
+                                                    arg.output_(n, k, ho, wo)));
+                                            arg.wei_element_op_(
+                                                v_wei,
+                                                ck::type_convert<AccDataType>(
+                                                    arg.weight_(k, c, y, x)));
+
+                                            v_acc += v_out * v_wei;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                AccDataType v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_nchw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2],
+                                       arg.input_.mDesc.GetLengths()[3])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
+                std::size_t K  = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t Z  = arg.weight_.mDesc.GetLengths()[2];
+                std::size_t Y  = arg.weight_.mDesc.GetLengths()[3];
+                std::size_t X  = arg.weight_.mDesc.GetLengths()[4];
+                std::size_t Do = arg.output_.mDesc.GetLengths()[2];
+                std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
+
+                AccDataType v_acc = 0;
+
+                for(int z = 0; z < Z; ++z)
+                {
+                    int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
+                    if(d_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int do_ = d_tmp / arg.conv_strides_[0];
+                        if(do_ >= 0 && do_ < Do)
+                        {
+                            for(int y = 0; y < Y; ++y)
+                            {
+                                int h_tmp = hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
+                                if(h_tmp % arg.conv_strides_[1] == 0)
+                                {
+                                    int ho = h_tmp / arg.conv_strides_[1];
+                                    if(ho >= 0 && ho < Ho)
+                                    {
+                                        for(int x = 0; x < X; ++x)
+                                        {
+                                            int w_tmp = wi + arg.in_left_pads_[2] -
+                                                        x * arg.conv_dilations_[2];
+                                            if(w_tmp % arg.conv_strides_[2] == 0)
+                                            {
+                                                int wo = w_tmp / arg.conv_strides_[2];
+                                                if(wo >= 0 && wo < Wo)
+                                                {
+                                                    for(int k = 0; k < K; ++k)
+                                                    {
+                                                        AccDataType v_out = 0;
+                                                        AccDataType v_wei = 0;
+
+                                                        arg.out_element_op_(
+                                                            v_out,
+                                                            ck::type_convert<AccDataType>(
+                                                                arg.output_(n, k, do_, ho, wo)));
+                                                        arg.wei_element_op_(
+                                                            v_wei,
+                                                            ck::type_convert<AccDataType>(
+                                                                arg.weight_(k, c, z, y, x)));
+
+                                                        v_acc += v_out * v_wei;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                AccDataType v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_ncdhw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2],
+                                       arg.input_.mDesc.GetLengths()[3],
+                                       arg.input_.mDesc.GetLengths()[4])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
     }

     float Run(const device::BaseArgument* p_arg, int, hipStream_t, bool) override

@@ -143,9 +278,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
     bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

-    static auto MakeArgument(Tensor<InDataType>& in_n_c_hi_wi,
-                             const Tensor<WeiDataType>& wei_k_c_y_x,
-                             const Tensor<OutDataType>& out_n_k_ho_wo,
+    static auto MakeArgument(Tensor<InDataType>& input,
+                             const Tensor<WeiDataType>& weight,
+                             const Tensor<OutDataType>& output,
                              std::vector<ck::index_t> conv_filter_strides,
                              std::vector<ck::index_t> conv_filter_dilations,
                              std::vector<ck::index_t> input_left_pads,

@@ -154,9 +289,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op)
     {
-        return Argument{in_n_c_hi_wi,
-                        wei_k_c_y_x,
-                        out_n_k_ho_wo,
+        return Argument{input,
+                        weight,
+                        output,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,

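The rewritten Run() above inverts the forward index map per spatial axis: an input position contributes only to output positions wo = (wi + pad − x·dilation) / stride when the division is exact and wo lands in [0, Wo). Below is a minimal single-channel 1-D restatement of that gather; it uses plain std::vector and identity elementwise operations instead of the library's Tensor and element-op types, so the names and signature here are illustrative only.

    // Simplified sketch of the 1-D backward-data loop above (single channel,
    // identity elementwise ops).
    #include <vector>

    // in_grad[wi] = sum over x of out_grad[wo] * weight[x],
    // where wo = (wi + pad - x * dilation) / stride, accepted only when the
    // division is exact and 0 <= wo < Wo.
    std::vector<float> conv1d_bwd_data(const std::vector<float>& out_grad, // length Wo
                                       const std::vector<float>& weight,   // length X
                                       int Wi, int stride, int dilation, int pad)
    {
        const int Wo = static_cast<int>(out_grad.size());
        const int X  = static_cast<int>(weight.size());

        std::vector<float> in_grad(Wi, 0.0f);
        for(int wi = 0; wi < Wi; ++wi)
        {
            float acc = 0.0f;
            for(int x = 0; x < X; ++x)
            {
                const int w_tmp = wi + pad - x * dilation;
                if(w_tmp % stride != 0)
                    continue; // this filter tap never touched input position wi
                const int wo = w_tmp / stride;
                if(wo >= 0 && wo < Wo)
                    acc += out_grad[wo] * weight[x];
            }
            in_grad[wi] = acc;
        }
        return in_grad;
    }
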
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp

@@ -14,9 +14,9 @@ namespace host {
 //
 // @brief Reference implementation for forward convolution.
 //
-// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout.
-//            Weights tensor supports KCYX data layout. Output tensor supports
-//            NKHoWo data layout.
+// @paragraph Supports both NCHW as well as NHWC formats (and their respective
+//            counterparts for weight and output) as long as tensor descriptor
+//            lengths is in NCHW.
 //
 // @tparam InDataType  Input tensor data type.
 // @tparam WeiDataType Weights tensor data type.

@@ -100,9 +100,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                         float v_in;
                         float v_wei;

-                        arg.in_element_op_(v_in,
-                                           static_cast<const float>(arg.input_(n, c, wi)));
-                        arg.wei_element_op_(v_wei,
-                                            static_cast<const float>(arg.weight_(k, c, x)));
+                        arg.in_element_op_(v_in,
+                                           ck::type_convert<float>(arg.input_(n, c, wi)));
+                        arg.wei_element_op_(v_wei,
+                                            ck::type_convert<float>(arg.weight_(k, c, x)));

                         v_acc += v_in * v_wei;
                     }

@@ -112,7 +112,7 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;

                 arg.out_element_op_(v_out, v_acc);
-                arg.output_(n, k, wo) = v_out;
+                arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
             };

             make_ParallelTensorFunctor(f_ncw,

@@ -169,6 +169,61 @@ struct ReferenceConvFwd : public device::BaseOperator
             return 0;
         }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
+                float v_acc = 0;
+
+                for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                {
+                    for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                    {
+                        int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
+                                 arg.in_left_pads_[0];
+                        for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                        {
+                            int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
+                                     arg.in_left_pads_[1];
+                            for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                            {
+                                int wi = wo * arg.conv_strides_[2] + x * arg.conv_dilations_[2] -
+                                         arg.in_left_pads_[2];
+                                if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] && hi >= 0 &&
+                                   hi < arg.input_.mDesc.GetLengths()[3] && wi >= 0 &&
+                                   wi < arg.input_.mDesc.GetLengths()[4])
+                                {
+                                    float v_in;
+                                    float v_wei;
+
+                                    arg.in_element_op_(
+                                        v_in,
+                                        ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
+                                    arg.wei_element_op_(
+                                        v_wei,
+                                        ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
+
+                                    v_acc += v_in * v_wei;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                float v_out;
+
+                arg.out_element_op_(v_out, v_acc);
+                arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
+            };
+
+            make_ParallelTensorFunctor(f_nchw,
+                                       arg.output_.mDesc.GetLengths()[0],
+                                       arg.output_.mDesc.GetLengths()[1],
+                                       arg.output_.mDesc.GetLengths()[2],
+                                       arg.output_.mDesc.GetLengths()[3],
+                                       arg.output_.mDesc.GetLengths()[4])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
     }

     float Run(const device::BaseArgument* p_arg, int, hipStream_t, bool) override

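The added NumDimSpatial == 3 branch applies the usual forward mapping per axis (di = d_o·stride + z·dilation − pad, and likewise for hi and wi) and skips taps that fall outside the input, which is how implicit zero padding is realized. A single-channel 1-D restatement of that loop is sketched below under the same simplifying assumptions as the backward-data sketch above; names and signature are illustrative only.

    // Simplified sketch of the forward reference loop (single channel,
    // identity elementwise ops): out[wo] = sum_x in[wi] * w[x] with
    // wi = wo * stride + x * dilation - pad, skipping out-of-range taps.
    #include <vector>

    std::vector<float> conv1d_fwd(const std::vector<float>& in,     // length Wi
                                  const std::vector<float>& weight, // length X
                                  int Wo, int stride, int dilation, int pad)
    {
        const int Wi = static_cast<int>(in.size());
        const int X  = static_cast<int>(weight.size());

        std::vector<float> out(Wo, 0.0f);
        for(int wo = 0; wo < Wo; ++wo)
        {
            float acc = 0.0f;
            for(int x = 0; x < X; ++x)
            {
                const int wi = wo * stride + x * dilation - pad;
                if(wi >= 0 && wi < Wi) // implicit zero padding outside the input
                    acc += in[wi] * weight[x];
            }
            out[wo] = acc;
        }
        return out;
    }
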
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp

@@ -6,23 +6,36 @@
 #include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
 #include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
 #include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
+#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
+#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
+#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
 #include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
 #include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
+#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
+#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
+#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
+#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
 #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
 #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
 #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
 #include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
 #include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
+#include "device_reduce_instance_threadwise_i8_i8_i8.hpp"
+#include "device_reduce_instance_threadwise_i8_i32_i8.hpp"
+#include "device_reduce_instance_threadwise_b16_f32_b16.hpp"
 #endif

library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp

@@ -17,7 +17,6 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     ReductionConfiguration_2<0, 2, 2, 2, 1>,
     ReductionConfiguration_2<0, 1, 1, 2, 1>,
     ReductionConfiguration_2<1, 2, 1, 1, 2>,
-    ReductionConfiguration_2<1, 2, 2, 1, 2>,
     ReductionConfiguration_2<0, 1, 1, 3, 1>,
     ReductionConfiguration_2<1, 1, 1, 1, 3>
     // clang-format on

@@ -48,7 +47,7 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     >;
 #endif

-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 using deviceReduceBlockWisePtrType = DeviceReducePtr<
     typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
     typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;

@@ -57,10 +56,10 @@ template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           int Rank,
-          typename ReduceDims,
-          ReduceTensorOp_t ReduceOpId,
-          NanPropagation_t NanOpt,
-          ReduceTensorIndices_t IndicesOpt>
+          int NumReduceDim,
+          ReduceTensorOp ReduceOpId,
+          NanPropagation NanOpt,
+          ReduceTensorIndices IndicesOpt>
 void add_device_reduce_instance_blockwise(
     std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
 {

@@ -72,11 +71,11 @@ void add_device_reduce_instance_blockwise(
         AccElementwiseOperation;

     constexpr bool Indexable =
-        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
-         ReduceOpId == ReduceTensorOp_t::AMAX);
-    constexpr bool NeedIndices  = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
-    constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
+        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+         ReduceOpId == ReduceTensorOp::AMAX);
+    constexpr bool NeedIndices  = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);
+    constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;

     static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
         using cfg1 =

@@ -91,7 +90,7 @@ void add_device_reduce_instance_blockwise(
             AccDataType,
             OutDataType,
             Rank,
-            ReduceDims,
+            NumReduceDim,
             ReduceOperation,
             InElementwiseOperation,
             AccElementwiseOperation,

@@ -112,34 +111,36 @@ void add_device_reduce_instance_blockwise(
     });
 };

-#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)  \
-    template void add_device_reduce_instance_blockwise<inT,                                      \
-                                                        compT,                                   \
-                                                        outT,                                    \
-                                                        Rank,                                    \
-                                                        Sequence<__VA_ARGS__>,                   \
-                                                        ReduceOpId,                              \
-                                                        NanOpt,                                  \
-                                                        IndicesOpt>(                             \
+#define ADD_BLOCKWISE_INST_BY_TYPE(                                                              \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                        \
+    template void add_device_reduce_instance_blockwise<inT,                                      \
+                                                        compT,                                   \
+                                                        outT,                                    \
+                                                        Rank,                                    \
+                                                        NumReduceDim,                            \
+                                                        ReduceOpId,                              \
+                                                        NanOpt,                                  \
+                                                        IndicesOpt>(                             \
         std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)

-#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)    \
-    ADD_BLOCKWISE_INST_BY_TYPE(inT,                                                              \
-                               compT,                                                            \
-                               outT,                                                             \
-                               static_cast<ReduceTensorOp_t>(ReduceOpId),                        \
-                               static_cast<NanPropagation_t>(NanOpt),                            \
-                               static_cast<ReduceTensorIndices_t>(IndicesOpt),                   \
-                               Rank,                                                             \
-                               __VA_ARGS__)
+#define ADD_BLOCKWISE_INST_BY_ID(                                                                \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                        \
+    ADD_BLOCKWISE_INST_BY_TYPE(inT,                                                              \
+                               compT,                                                            \
+                               outT,                                                             \
+                               static_cast<ReduceTensorOp>(ReduceOpId),                          \
+                               static_cast<NanPropagation>(NanOpt),                              \
+                               static_cast<ReduceTensorIndices>(IndicesOpt),                     \
+                               Rank,                                                             \
+                               NumReduceDim)

 #define ADD_BLOCKWISE_INST_REF_BY_TYPE(                                                          \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                                 \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                        \
     extern template void add_device_reduce_instance_blockwise<inT,                               \
                                                                compT,                            \
                                                                outT,                             \
                                                                Rank,                             \
-                                                               Sequence<__VA_ARGS__>,            \
+                                                               NumReduceDim,                     \
                                                                ReduceOpId,                       \
                                                                NanOpt,                           \
                                                                IndicesOpt>(                      \

@@ -149,15 +150,16 @@ void add_device_reduce_instance_blockwise(
                                       AccElementwiseOperation>> &                                \
         device_op_instances)

-#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                                          \
-                                   compT,                                                        \
-                                   outT,                                                         \
-                                   static_cast<ReduceTensorOp_t>(ReduceOpId),                    \
-                                   static_cast<NanPropagation_t>(NanOpt),                        \
-                                   static_cast<ReduceTensorIndices_t>(IndicesOpt),               \
-                                   Rank,                                                         \
-                                   __VA_ARGS__)
+#define ADD_BLOCKWISE_INST_REF_BY_ID(                                                            \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                        \
+    ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                                          \
+                                   compT,                                                        \
+                                   outT,                                                         \
+                                   static_cast<ReduceTensorOp>(ReduceOpId),                      \
+                                   static_cast<NanPropagation>(NanOpt),                          \
+                                   static_cast<ReduceTensorIndices>(IndicesOpt),                 \
+                                   Rank,                                                         \
+                                   NumReduceDim)

 } // namespace device_reduce_instance
 } // namespace device

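The macro rework above replaces the trailing variadic list of reduced dimensions (previously wrapped into a Sequence<...> inside the macro) with a single NumReduceDim count, and drops the `_t` suffix from the enum casts. The toy sketch below only illustrates that parameter change with a stand-in Sequence type; the real template signature is the one shown in the diff.

    // Minimal sketch of the template-parameter change, using a toy Sequence
    // type; not the library's actual add_device_reduce_instance_blockwise.
    #include <cstddef>

    template <int... Is>
    struct Sequence
    {
        static constexpr std::size_t Size() { return sizeof...(Is); }
    };

    // Old style: the reduced dimensions were enumerated as a type.
    template <int Rank, typename ReduceDims>
    constexpr int num_reduce_dim_old()
    {
        return static_cast<int>(ReduceDims::Size());
    }

    // New style: only their count is a template parameter.
    template <int Rank, int NumReduceDim>
    constexpr int num_reduce_dim_new()
    {
        return NumReduceDim;
    }

    static_assert(num_reduce_dim_old<4, Sequence<0, 1, 2>>() == 3, "rank 4, reduce dims {0,1,2}");
    static_assert(num_reduce_dim_new<4, 3>() == 3, "rank 4, three reduced dims");

    int main() { return 0; }
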
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp  (new file, 0 → 100644)

#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP

#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {

// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on

} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

#endif
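The integer ReduceOpId arguments in this new instance file map onto ReduceTensorOp enumerators; the mapping below is the one implied by the file's own "// for ..." comments (0 ADD, 5 AVG, 7 NORM2, 2 MIN, 3 MAX, 4 AMAX). The enum sketch is a stand-in for illustration, not the definition in reduction_enums.hpp.

    // Stand-in enum reflecting only the IDs used by the comments above;
    // other enumerators exist in the library's reduction_enums.hpp.
    enum struct ReduceTensorOpSketch
    {
        ADD   = 0,
        MIN   = 2,
        MAX   = 3,
        AMAX  = 4,
        AVG   = 5,
        NORM2 = 7,
    };

    static_assert(static_cast<int>(ReduceTensorOpSketch::NORM2) == 7,
                  "NORM2 instances above use id 7");

    int main() { return 0; }
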
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp

@@ -11,25 +11,31 @@ namespace device {
 namespace device_reduce_instance {

 // clang-format off
-//                              InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); //
+//                              InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance

library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp

@@ -11,16 +11,19 @@ namespace device {
 namespace device_reduce_instance {

 // clang-format off
-//                              InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
+//                              InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
 ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0); //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); //
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
 } // namespace device_reduce_instance
