gaoqiong / composable_kernel / Commits

Commit a1841d55, authored Aug 01, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into lwpck-367

Parents: 127bf7f4, 500fa995
Changes: 373

Showing 20 changed files with 1059 additions and 319 deletions (+1059 -319)
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp (+128 -107)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp (+97 -85)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp (+105 -83)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp (+1 -1)
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp (+2 -2)
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp (+2 -2)
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp (+50 -5)
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp (+4 -4)
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp (+4 -4)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp (+270 -0, new file)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp (+230 -0, new file)
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp (+128 -0)
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp (+31 -19)
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp

@@ -8,22 +8,24 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace host {

-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// phyiscal layout is irrelavent
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdData : public device::BaseOperator
 {
     // Argument
...

@@ -73,36 +75,45 @@ struct ReferenceConvBwdData : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
         {
-            auto f_ncw = [&](auto n, auto c, auto wi) {
-                std::size_t K  = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t X  = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
+        {
+            auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
+                std::size_t K  = arg.weight_.GetLengths()[1];
+                std::size_t X  = arg.weight_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[3];

-                AccDataType v_acc = 0;
+                float v_acc = 0;

                 for(std::size_t x = 0; x < X; ++x)
                 {
-                    auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
-                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
-                                 ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
+                    auto w_tmp = static_cast<ck::long_index_t>(wi) +
+                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);

                     if(w_tmp % arg.conv_strides_[0] == 0)
                     {
-                        auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
-                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        auto wo = static_cast<ck::long_index_t>(w_tmp) /
+                                  static_cast<ck::long_index_t>(arg.conv_strides_[0]);

                         if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                         {
                             for(std::size_t k = 0; k < K; ++k)
                             {
-                                AccDataType v_out = 0;
-                                AccDataType v_wei = 0;
+                                float v_out = 0;
+                                float v_wei = 0;

                                 arg.out_element_op_(
-                                    v_out, ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
+                                    v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
                                 arg.wei_element_op_(
-                                    v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
+                                    v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));

                                 v_acc += v_out * v_wei;
                             }
...

@@ -110,66 +121,72 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }

-                arg.in_element_op_(v_acc, v_acc);
-                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
+                float v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
             };

             make_ParallelTensorFunctor(f_ncw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       arg.input_.mDesc.GetLengths()[1],
-                                       arg.input_.mDesc.GetLengths()[2])(
+                                       arg.input_.GetLengths()[0],
+                                       arg.input_.GetLengths()[1],
+                                       arg.input_.GetLengths()[2],
+                                       arg.input_.GetLengths()[3])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
         {
-            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
-                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t X = arg.weight_.mDesc.GetLengths()[3];
+            auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
+                std::size_t K = arg.weight_.GetLengths()[1];
+                std::size_t Y = arg.weight_.GetLengths()[3];
+                std::size_t X = arg.weight_.GetLengths()[4];

-                std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
+                std::size_t Ho = arg.output_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[4];

-                AccDataType v_acc = 0;
+                float v_acc = 0;

                 for(std::size_t y = 0; y < Y; ++y)
                 {
-                    auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
-                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
-                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
+                    auto h_tmp = static_cast<ck::long_index_t>(hi) +
+                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);

                     if(h_tmp % arg.conv_strides_[0] == 0)
                     {
-                        auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
-                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        auto ho = static_cast<ck::long_index_t>(h_tmp) /
+                                  static_cast<ck::long_index_t>(arg.conv_strides_[0]);

                         if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                         {
                             for(std::size_t x = 0; x < X; ++x)
                             {
-                                auto w_tmp =
-                                    ck::type_convert<ck::long_index_t>(wi) +
-                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
-                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]);
+                                auto w_tmp =
+                                    static_cast<ck::long_index_t>(wi) +
+                                    static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);

                                 if(w_tmp % arg.conv_strides_[1] == 0)
                                 {
-                                    auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
-                                              ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]);
+                                    auto wo = static_cast<ck::long_index_t>(w_tmp) /
+                                              static_cast<ck::long_index_t>(arg.conv_strides_[1]);

                                     if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                     {
                                         for(std::size_t k = 0; k < K; ++k)
                                         {
-                                            AccDataType v_out = 0;
-                                            AccDataType v_wei = 0;
+                                            float v_out = 0;
+                                            float v_wei = 0;

                                             arg.out_element_op_(
                                                 v_out,
-                                                ck::type_convert<AccDataType>(arg.output_(n, k, ho, wo)));
+                                                ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
                                             arg.wei_element_op_(
                                                 v_wei,
-                                                ck::type_convert<AccDataType>(arg.weight_(k, c, y, x)));
+                                                ck::type_convert<float>(arg.weight_(g, k, c, y, x)));

                                             v_acc += v_out * v_wei;
                                         }
...

@@ -180,90 +197,91 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }

-                AccDataType v_in;
+                float v_in;
                 arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
+                arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
             };

             make_ParallelTensorFunctor(f_nchw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       arg.input_.mDesc.GetLengths()[1],
-                                       arg.input_.mDesc.GetLengths()[2],
-                                       arg.input_.mDesc.GetLengths()[3])(
+                                       arg.input_.GetLengths()[0],
+                                       arg.input_.GetLengths()[1],
+                                       arg.input_.GetLengths()[2],
+                                       arg.input_.GetLengths()[3],
+                                       arg.input_.GetLengths()[4])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 3)
+        else if constexpr(NDimSpatial == 3)
         {
-            auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
-                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
-                std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
-                std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
-                std::size_t X = arg.weight_.mDesc.GetLengths()[4];
+            auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
+                std::size_t K = arg.weight_.GetLengths()[1];
+                std::size_t Z = arg.weight_.GetLengths()[3];
+                std::size_t Y = arg.weight_.GetLengths()[4];
+                std::size_t X = arg.weight_.GetLengths()[5];

-                std::size_t Do = arg.output_.mDesc.GetLengths()[2];
-                std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
-                std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
+                std::size_t Do = arg.output_.GetLengths()[3];
+                std::size_t Ho = arg.output_.GetLengths()[4];
+                std::size_t Wo = arg.output_.GetLengths()[5];

-                AccDataType v_acc = 0;
+                float v_acc = 0;

                 for(std::size_t z = 0; z < Z; ++z)
                 {
-                    auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
-                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
-                                 ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
+                    auto d_tmp = static_cast<ck::long_index_t>(di) +
+                                 static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);

                     if(d_tmp % arg.conv_strides_[0] == 0)
                     {
-                        auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
-                                   ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        auto do_ = static_cast<ck::long_index_t>(d_tmp) /
+                                   static_cast<ck::long_index_t>(arg.conv_strides_[0]);

                         if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
                         {
                             for(std::size_t y = 0; y < Y; ++y)
                             {
-                                auto h_tmp =
-                                    ck::type_convert<ck::long_index_t>(hi) +
-                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
-                                    ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]);
+                                auto h_tmp =
+                                    static_cast<ck::long_index_t>(hi) +
+                                    static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);

                                 if(h_tmp % arg.conv_strides_[1] == 0)
                                 {
-                                    auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
-                                              ck::type_convert<ck::long_index_t>(arg.conv_strides_[1]);
+                                    auto ho = static_cast<ck::long_index_t>(h_tmp) /
+                                              static_cast<ck::long_index_t>(arg.conv_strides_[1]);

                                     if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                                     {
                                         for(std::size_t x = 0; x < X; ++x)
                                         {
-                                            auto w_tmp =
-                                                ck::type_convert<ck::long_index_t>(wi) +
-                                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]) -
-                                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]);
+                                            auto w_tmp =
+                                                static_cast<ck::long_index_t>(wi) +
+                                                static_cast<ck::long_index_t>(arg.in_left_pads_[2]) -
+                                                static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]);

                                             if(w_tmp % arg.conv_strides_[2] == 0)
                                             {
-                                                auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
-                                                          ck::type_convert<ck::long_index_t>(arg.conv_strides_[2]);
+                                                auto wo = static_cast<ck::long_index_t>(w_tmp) /
+                                                          static_cast<ck::long_index_t>(arg.conv_strides_[2]);

                                                 if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                                 {
                                                     for(std::size_t k = 0; k < K; ++k)
                                                     {
-                                                        AccDataType v_out = 0;
-                                                        AccDataType v_wei = 0;
+                                                        float v_out = 0;
+                                                        float v_wei = 0;

                                                         arg.out_element_op_(
                                                             v_out,
-                                                            ck::type_convert<AccDataType>(
-                                                                arg.output_(n, k, do_, ho, wo)));
+                                                            ck::type_convert<float>(
+                                                                arg.output_(g, n, k, do_, ho, wo)));
                                                         arg.wei_element_op_(
                                                             v_wei,
-                                                            ck::type_convert<AccDataType>(
-                                                                arg.weight_(k, c, z, y, x)));
+                                                            ck::type_convert<float>(
+                                                                arg.weight_(g, k, c, z, y, x)));

                                                         v_acc += v_out * v_wei;
                                                     }
...

@@ -277,17 +295,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }

-                AccDataType v_in;
+                float v_in;
                 arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
+                arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
             };

             make_ParallelTensorFunctor(f_ncdhw,
-                                       arg.input_.mDesc.GetLengths()[0],
-                                       arg.input_.mDesc.GetLengths()[1],
-                                       arg.input_.mDesc.GetLengths()[2],
-                                       arg.input_.mDesc.GetLengths()[3],
-                                       arg.input_.mDesc.GetLengths()[4])(
+                                       arg.input_.GetLengths()[0],
+                                       arg.input_.GetLengths()[1],
+                                       arg.input_.GetLengths()[2],
+                                       arg.input_.GetLengths()[3],
+                                       arg.input_.GetLengths()[4],
+                                       arg.input_.GetLengths()[5])(
                 std::thread::hardware_concurrency());

             return 0;
...
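
The backward-data loops above invert the forward convolution's index map: in the forward pass an output sample wo reads input position wi = wo * stride + x * dilation - left_pad, so backward-data fixes (wi, x) and recovers the contributing output position instead. The standalone C++ sketch below (not part of the commit; all names are illustrative) isolates that inversion for one spatial dimension:

#include <cstdint>
#include <optional>

// For a fixed input position wi and filter tap x, find the output position wo
// (if any) that reads wi in the forward pass.
std::optional<std::int64_t> contributing_output_index(std::int64_t wi,
                                                      std::int64_t x,
                                                      std::int64_t stride,
                                                      std::int64_t dilation,
                                                      std::int64_t left_pad,
                                                      std::int64_t Wo)
{
    // Solve wi = wo * stride + x * dilation - left_pad for wo.
    const std::int64_t w_tmp = wi + left_pad - x * dilation;

    // The tap contributes only when w_tmp lands exactly on the stride grid.
    if(w_tmp % stride != 0)
        return std::nullopt;

    const std::int64_t wo = w_tmp / stride;

    // ... and only when the recovered output position is in range.
    if(wo < 0 || wo >= Wo)
        return std::nullopt;

    return wo;
}

This is exactly the w_tmp / wo test in the 1-D branch; the 2-D and 3-D branches apply the same test independently per spatial axis.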
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
→ library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -7,21 +7,25 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace host {

-// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// phyiscal layout is irrelavent
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
...

@@ -71,156 +75,162 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
+        {
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
         {
             constexpr auto I0 = Number<0>{};

-            auto f_kcx = [&](auto k, auto c, auto x) {
+            auto f_kcx = [&](auto g, auto k, auto c, auto x) {
                 float v_acc = 0;

-                for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
                 {
-                    for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
+                    for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
                     {
-                        auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
-                                  ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                        auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);

                         if(wi >= 0 &&
-                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                           ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                         {
                             float v_out;
                             float v_in;

-                            arg.out_element_op_(v_out, ck::type_convert<float>(arg.output_(n, k, wo)));
-                            arg.in_element_op_(v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
+                            arg.out_element_op_(v_out,
+                                                ck::type_convert<float>(arg.output_(g, n, k, wo)));
+                            arg.in_element_op_(v_in,
+                                               ck::type_convert<float>(arg.input_(g, n, c, wi)));

                             v_acc += v_out * v_in;
                         }
                     }
                 }

                 float v_wei;

                 arg.wei_element_op_(v_wei, v_acc);

-                arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
+                arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
             };

             make_ParallelTensorFunctor(f_kcx,
-                                       arg.weight_.mDesc.GetLengths()[0],
-                                       arg.weight_.mDesc.GetLengths()[1],
-                                       arg.weight_.mDesc.GetLengths()[2])(
+                                       arg.weight_.GetLengths()[0],
+                                       arg.weight_.GetLengths()[1],
+                                       arg.weight_.GetLengths()[2],
+                                       arg.weight_.GetLengths()[3])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
         {
             constexpr auto I0 = Number<0>{};
             constexpr auto I1 = Number<1>{};

-            auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
+            auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
                 float v_acc = 0;

-                for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
                 {
-                    for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
+                    for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
                     {
-                        auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
-                                  ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
-                        for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
+                        auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
                         {
-                            auto wi =
-                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
-                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) -
-                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
+                            auto wi =
+                                static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                                static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                                static_cast<ck::long_index_t>(arg.in_left_pads_[1]);

                             if(hi >= 0 &&
-                               ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[2] &&
+                               ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
                                wi >= 0 &&
-                               ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[3])
+                               ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
                             {
                                 float v_out;
                                 float v_in;

                                 arg.out_element_op_(
-                                    v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
+                                    v_out, ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
                                 arg.in_element_op_(
-                                    v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
+                                    v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));

                                 v_acc += v_out * v_in;
                             }
                         }
                     }
                 }

                 float v_wei;

                 arg.wei_element_op_(v_wei, v_acc);

-                arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
+                arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
             };

             make_ParallelTensorFunctor(f_kcyx,
-                                       arg.weight_.mDesc.GetLengths()[0],
-                                       arg.weight_.mDesc.GetLengths()[1],
-                                       arg.weight_.mDesc.GetLengths()[2],
-                                       arg.weight_.mDesc.GetLengths()[3])(
+                                       arg.weight_.GetLengths()[0],
+                                       arg.weight_.GetLengths()[1],
+                                       arg.weight_.GetLengths()[2],
+                                       arg.weight_.GetLengths()[3],
+                                       arg.weight_.GetLengths()[4])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 3)
+        else if constexpr(NDimSpatial == 3)
         {
             constexpr auto I0 = Number<0>{};
             constexpr auto I1 = Number<1>{};
             constexpr auto I2 = Number<2>{};

-            auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
+            auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
                 float v_acc = 0;

-                for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
                 {
-                    for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
+                    for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
                     {
-                        auto di = ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
-                                  ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
-                        for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
+                        auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
                         {
-                            auto hi =
-                                ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
-                                ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I1]) -
-                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
-                            for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4]; ++wo)
+                            auto hi =
+                                static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
+                                static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
+                                static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
+                            for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
                             {
-                                auto wi =
-                                    ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I2]) +
-                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I2]) -
-                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
+                                auto wi =
+                                    static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
+                                    static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
+                                    static_cast<ck::long_index_t>(arg.in_left_pads_[2]);

                                 if(di >= 0 &&
-                                   ck::type_convert<std::size_t>(di) < arg.input_.mDesc.GetLengths()[2] &&
+                                   ck::type_convert<std::size_t>(di) < arg.input_.GetLengths()[3] &&
                                    hi >= 0 &&
-                                   ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[3] &&
+                                   ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[4] &&
                                    wi >= 0 &&
-                                   ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[4])
+                                   ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[5])
                                 {
                                     float v_out;
                                     float v_in;

                                     arg.out_element_op_(
-                                        v_out, ck::type_convert<float>(arg.output_(n, k, do_, ho, wo)));
+                                        v_out, ck::type_convert<float>(arg.output_(g, n, k, do_, ho, wo)));
                                     arg.in_element_op_(
-                                        v_in, ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
+                                        v_in, ck::type_convert<float>(arg.input_(g, n, c, di, hi, wi)));

                                     v_acc += v_out * v_in;
                                 }
...

@@ -228,19 +238,21 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                             }
                         }
                     }
                 }

                 float v_wei;

                 arg.wei_element_op_(v_wei, v_acc);

-                arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
+                arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
             };

             make_ParallelTensorFunctor(f_kczyx,
-                                       arg.weight_.mDesc.GetLengths()[0],
-                                       arg.weight_.mDesc.GetLengths()[1],
-                                       arg.weight_.mDesc.GetLengths()[2],
-                                       arg.weight_.mDesc.GetLengths()[3],
-                                       arg.weight_.mDesc.GetLengths()[4])(
+                                       arg.weight_.GetLengths()[0],
+                                       arg.weight_.GetLengths()[1],
+                                       arg.weight_.GetLengths()[2],
+                                       arg.weight_.GetLengths()[3],
+                                       arg.weight_.GetLengths()[4],
+                                       arg.weight_.GetLengths()[5])(
                 std::thread::hardware_concurrency());

             return 0;
...
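
In the backward-weight reference above, each weight element is the correlation of the output gradient and the input over the batch and output spatial positions; that is what the f_kcx / f_kcyx / f_kczyx lambdas compute per (g, k, c, ...) coordinate. A minimal standalone sketch of the 1-D, ungrouped case, assuming dense row-major float arrays (illustrative names, not the library's API):

#include <cstddef>
#include <vector>

void conv1d_bwd_weight_ref(std::vector<float>& wei,        // [K, C, X]
                           const std::vector<float>& in,   // [N, C, Wi]
                           const std::vector<float>& out,  // [N, K, Wo]
                           std::size_t N, std::size_t K, std::size_t C,
                           std::size_t X, std::size_t Wi, std::size_t Wo,
                           std::ptrdiff_t stride, std::ptrdiff_t dilation,
                           std::ptrdiff_t left_pad)
{
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t x = 0; x < X; ++x)
            {
                float v_acc = 0;
                for(std::size_t n = 0; n < N; ++n)
                    for(std::size_t wo = 0; wo < Wo; ++wo)
                    {
                        // Same forward index arithmetic as the reference op above.
                        const std::ptrdiff_t wi =
                            static_cast<std::ptrdiff_t>(wo) * stride +
                            static_cast<std::ptrdiff_t>(x) * dilation - left_pad;
                        if(wi >= 0 && static_cast<std::size_t>(wi) < Wi)
                            v_acc += out[(n * K + k) * Wo + wo] *
                                     in[(n * C + c) * Wi + static_cast<std::size_t>(wi)];
                    }
                wei[(k * C + c) * X + x] = v_acc;
            }
}

The grouped variant in the diff simply adds an outer g coordinate to every tensor access, exactly as the lambdas show.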
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp

@@ -8,7 +8,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...

@@ -17,9 +17,10 @@ namespace host {
 //
 // @brief Reference implementation for forward convolution.
 //
-// @paragraph Supports both NCHW as well as NHWC formats (and their respective
-//            counterparts for weight and output) as long as tensor descriptor
-//            lengths is in NCHW.
+// @paragraph
+// Tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
+// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
+// as long as dimensions in tensor descriptor is in GNCHW order
 //
 // @tparam InDataType  Input tensor data type.
 // @tparam WeiDataType Weights tensor data type.
...

@@ -28,16 +29,20 @@ namespace host {
 //                                  operation.
 // @tparam WeiElementwiseOperation  Functor for weights tensor elementwise
 //                                  operation.
-// @tparam NumDimSpatial  Number of spatial dimensions.
+// @tparam NDimSpatial  Number of spatial dimensions.
 //
-template <typename InDataType,
+// input descriptor in [G, N, C, Do, Ho, Wo] order
+// weight descriptor in [G, K, C, Z, Y, X] order
+// output descriptor in [G, N, K, Di, Hi, Wi] order
+// phyiscal layout is irrelavent
+template <ck::index_t NDimSpatial,
+          typename InDataType,
           typename WeiDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation,
-          ck::index_t NumDimSpatial = 2,
-          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
+          typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvFwd : public device::BaseOperator
 {
     // Argument
...

@@ -86,29 +91,37 @@ struct ReferenceConvFwd : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        if constexpr(NumDimSpatial == 1)
+        if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
+             arg.output_.GetNumOfDimension() == NDimSpatial + 3))
         {
-            auto f_ncw = [&](auto n, auto k, auto wo) {
+            throw std::runtime_error("wrong! inconsistent dimension");
+        }
+
+        if constexpr(NDimSpatial == 1)
+        {
+            auto func = [&](auto g, auto n, auto k, auto wo) {
                 float v_acc = 0;

-                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                 {
-                    for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                    for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
                     {
-                        auto wi = ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
-                                  ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                        auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);

                         if(wi >= 0 &&
-                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                           ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;

-                            arg.in_element_op_(v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
-                            arg.wei_element_op_(v_wei, ck::type_convert<float>(arg.weight_(k, c, x)));
+                            arg.in_element_op_(v_in,
+                                               ck::type_convert<float>(arg.input_(g, n, c, wi)));
+                            arg.wei_element_op_(v_wei,
+                                                ck::type_convert<float>(arg.weight_(g, k, c, x)));

                             v_acc += v_in * v_wei;
                         }
...

@@ -118,50 +131,53 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;

                 arg.out_element_op_(v_out, v_acc);

-                arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
+                arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
             };

-            make_ParallelTensorFunctor(f_ncw,
-                                       arg.output_.mDesc.GetLengths()[0],
-                                       arg.output_.mDesc.GetLengths()[1],
-                                       arg.output_.mDesc.GetLengths()[2])(
+            make_ParallelTensorFunctor(func,
+                                       arg.output_.GetLengths()[0],
+                                       arg.output_.GetLengths()[1],
+                                       arg.output_.GetLengths()[2],
+                                       arg.output_.GetLengths()[3])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 2)
+        else if constexpr(NDimSpatial == 2)
         {
-            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
+            auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
                 float v_acc = 0;

-                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                 {
-                    for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
+                    for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
                     {
-                        auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
-                                  ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
+                        auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
                         {
-                            auto wi =
-                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
-                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
-                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                            auto wi =
+                                static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                                static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                                static_cast<ck::long_index_t>(arg.in_left_pads_[1]);

                             if(hi >= 0 &&
-                               ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[2] &&
+                               ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
                                wi >= 0 &&
-                               ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[3])
+                               ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
                             {
                                 float v_in;
                                 float v_wei;

                                 arg.in_element_op_(
-                                    v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
+                                    v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
                                 arg.wei_element_op_(
-                                    v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
+                                    v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));

                                 v_acc += v_in * v_wei;
                             }
                         }
...

@@ -171,64 +187,65 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;

                 arg.out_element_op_(v_out, v_acc);

-                arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
+                arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
             };

-            make_ParallelTensorFunctor(f_nchw,
-                                       arg.output_.mDesc.GetLengths()[0],
-                                       arg.output_.mDesc.GetLengths()[1],
-                                       arg.output_.mDesc.GetLengths()[2],
-                                       arg.output_.mDesc.GetLengths()[3])(
+            make_ParallelTensorFunctor(func,
+                                       arg.output_.GetLengths()[0],
+                                       arg.output_.GetLengths()[1],
+                                       arg.output_.GetLengths()[2],
+                                       arg.output_.GetLengths()[3],
+                                       arg.output_.GetLengths()[4])(
                 std::thread::hardware_concurrency());

             return 0;
         }
-        else if constexpr(NumDimSpatial == 3)
+        else if constexpr(NDimSpatial == 3)
        {
-            auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
+            auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
                 float v_acc = 0;

-                for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                 {
-                    for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                    for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
                     {
-                        auto di = ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
-                                  ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
-                                  ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                        auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
+                                  static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
+                                  static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
+                        for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
                         {
-                            auto hi =
-                                ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
-                                ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
-                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
-                            for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                            auto hi =
+                                static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
+                                static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
+                                static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
+                            for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
                             {
-                                auto wi =
-                                    ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[2]) +
-                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]) -
-                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
+                                auto wi =
+                                    static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
+                                    static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
+                                    static_cast<ck::long_index_t>(arg.in_left_pads_[2]);

                                 if(di >= 0 &&
-                                   ck::type_convert<std::size_t>(di) < arg.input_.mDesc.GetLengths()[2] &&
+                                   ck::type_convert<std::size_t>(di) < arg.input_.GetLengths()[3] &&
                                    hi >= 0 &&
-                                   ck::type_convert<std::size_t>(hi) < arg.input_.mDesc.GetLengths()[3] &&
+                                   ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[4] &&
                                    wi >= 0 &&
-                                   ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[4])
+                                   ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[5])
                                 {
                                     float v_in;
                                     float v_wei;

                                     arg.in_element_op_(
-                                        v_in, ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
+                                        v_in, ck::type_convert<float>(arg.input_(g, n, c, di, hi, wi)));
                                     arg.wei_element_op_(
-                                        v_wei, ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
+                                        v_wei, ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));

                                     v_acc += v_in * v_wei;
                                 }
                             }
...

@@ -239,15 +256,17 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;

                 arg.out_element_op_(v_out, v_acc);

-                arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
+                arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
             };

-            make_ParallelTensorFunctor(f_nchw,
-                                       arg.output_.mDesc.GetLengths()[0],
-                                       arg.output_.mDesc.GetLengths()[1],
-                                       arg.output_.mDesc.GetLengths()[2],
-                                       arg.output_.mDesc.GetLengths()[3],
-                                       arg.output_.mDesc.GetLengths()[4])(
+            make_ParallelTensorFunctor(func,
+                                       arg.output_.GetLengths()[0],
+                                       arg.output_.GetLengths()[1],
+                                       arg.output_.GetLengths()[2],
+                                       arg.output_.GetLengths()[3],
+                                       arg.output_.GetLengths()[4],
+                                       arg.output_.GetLengths()[5])(
                 std::thread::hardware_concurrency());

             return 0;
...

@@ -267,7 +286,10 @@ struct ReferenceConvFwd : public device::BaseOperator
         return true;
     }

-    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
+    bool IsSupportedArgument(const device::BaseArgument*) override
+    {
+        return NDimSpatial >= 1 && NDimSpatial <= 3;
+    }

     static auto MakeArgument(const Tensor<InDataType>& input,
                              const Tensor<WeiDataType>& weight,
...
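
The forward reference computes, per output element, out(g, n, k, wo) = sum over (c, x) of in(g, n, c, wo*stride + x*dilation - left_pad) * wei(g, k, c, x), with out-of-range input positions treated as zero padding. A minimal standalone sketch of the grouped 1-D case over dense row-major float arrays (illustrative names, not the library's API):

#include <cstddef>
#include <vector>

void grouped_conv1d_fwd_ref(std::vector<float>& out,       // [G, N, K, Wo]
                            const std::vector<float>& in,  // [G, N, C, Wi]
                            const std::vector<float>& wei, // [G, K, C, X]
                            std::size_t G, std::size_t N, std::size_t K, std::size_t C,
                            std::size_t X, std::size_t Wi, std::size_t Wo,
                            std::ptrdiff_t stride, std::ptrdiff_t dilation,
                            std::ptrdiff_t left_pad)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t k = 0; k < K; ++k)
                for(std::size_t wo = 0; wo < Wo; ++wo)
                {
                    float v_acc = 0;
                    for(std::size_t c = 0; c < C; ++c)
                        for(std::size_t x = 0; x < X; ++x)
                        {
                            // Forward index map; out-of-range wi acts as zero padding.
                            const std::ptrdiff_t wi =
                                static_cast<std::ptrdiff_t>(wo) * stride +
                                static_cast<std::ptrdiff_t>(x) * dilation - left_pad;
                            if(wi >= 0 && static_cast<std::size_t>(wi) < Wi)
                                v_acc +=
                                    in[((g * N + n) * C + c) * Wi + static_cast<std::size_t>(wi)] *
                                    wei[((g * K + k) * C + c) * X + x];
                        }
                    out[((g * N + n) * K + k) * Wo + wo] = v_acc;
                }
}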
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp

@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp

@@ -8,7 +8,7 @@
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
+#include "ck/library/utility/host_tensor.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp

@@ -9,8 +9,8 @@
 #include <algorithm>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp

@@ -9,8 +9,8 @@
 #include <algorithm>
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
-#include "ck/library/host_tensor/host_tensor.hpp"
-#include "ck/library/host_tensor/host_tensor_generator.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"

 namespace ck {
 namespace tensor_operation {
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -10,22 +10,67 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-// aliasing, for commonly used type
+// aliasing, for commonly used data type
 using F64  = double;
 using F32  = float;
 using F16  = ck::half_t;
 using BF16 = ck::bhalf_t;

-using EMPTY_TUPLE = ck::Tuple<>;
+using Empty_Tuple = ck::Tuple<>;

-using F16_TUPLE     = ck::Tuple<F16>;
-using F16_F16_TUPLE = ck::Tuple<F16, F16>;
+using F16_Tuple     = ck::Tuple<F16>;
+using F16_F16_Tuple = ck::Tuple<F16, F16>;

-using F32_TUPLE = ck::Tuple<F32>;
+using F32_Tuple = ck::Tuple<F32>;

 // GEMM layout
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

 using Row_Tuple     = ck::Tuple<Row>;
 using Row_Row_Tuple = ck::Tuple<Row, Row>;

 // Conv layout
+//
 using NWC   = ck::tensor_layout::convolution::NWC;
 using NHWC  = ck::tensor_layout::convolution::NHWC;
 using NDHWC = ck::tensor_layout::convolution::NDHWC;

 using KXC   = ck::tensor_layout::convolution::KXC;
 using KYXC  = ck::tensor_layout::convolution::KYXC;
 using KZYXC = ck::tensor_layout::convolution::KZYXC;

 using NWK   = ck::tensor_layout::convolution::NWK;
 using NHWK  = ck::tensor_layout::convolution::NHWK;
 using NDHWK = ck::tensor_layout::convolution::NDHWK;

+//
+using GNWC   = ck::tensor_layout::convolution::GNWC;
+using GNHWC  = ck::tensor_layout::convolution::GNHWC;
+using GNDHWC = ck::tensor_layout::convolution::GNDHWC;
+
+using GKXC   = ck::tensor_layout::convolution::GKXC;
+using GKYXC  = ck::tensor_layout::convolution::GKYXC;
+using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
+
+using GNWK   = ck::tensor_layout::convolution::GNWK;
+using GNHWK  = ck::tensor_layout::convolution::GNHWK;
+using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
+
+//
+using NWGC   = ck::tensor_layout::convolution::NWGC;
+using NHWGC  = ck::tensor_layout::convolution::NHWGC;
+using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
+
+using KXGC   = ck::tensor_layout::convolution::KXGC;
+using KYXGC  = ck::tensor_layout::convolution::KYXGC;
+using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
+
+using NWGK   = ck::tensor_layout::convolution::NWGK;
+using NHWGK  = ck::tensor_layout::convolution::NHWGK;
+using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
+
 // pointwise functor
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale       = ck::tensor_operation::element_wise::Scale;
 using Bilinear    = ck::tensor_operation::element_wise::Bilinear;
...
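
The tuple-alias renames above (EMPTY_TUPLE to Empty_Tuple, F32_TUPLE to F32_Tuple, and so on) are purely cosmetic: each name is a plain type alias for the same ck::Tuple instantiation, so call sites only change spelling. A sketch of how that could be checked, assuming only the header shown above:

#include <type_traits>
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace inst = ck::tensor_operation::device::instance;

// The renamed aliases still denote the same types the ALL-CAPS spellings did.
static_assert(std::is_same_v<inst::Empty_Tuple, ck::Tuple<>>);
static_assert(std::is_same_v<inst::F16_Tuple, ck::Tuple<inst::F16>>);
static_assert(std::is_same_v<inst::F32_Tuple, ck::Tuple<inst::F32>>);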
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp

@@ -25,7 +25,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn
     2,
     F32,
     F32,
-    F32_TUPLE,
+    F32_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -37,7 +37,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn
     2,
     F32,
     F32,
-    F32_TUPLE,
+    F32_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -49,7 +49,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn
     2,
     F32,
     F32,
-    F32_TUPLE,
+    F32_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -61,7 +61,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
     2,
     F32,
     F32,
-    F32_TUPLE,
+    F32_Tuple,
     F32,
     PassThrough,
     PassThrough,
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp

@@ -25,7 +25,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instanc
     2,
     F32,
     F32,
-    EMPTY_TUPLE,
+    Empty_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -37,7 +37,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instanc
     2,
     F32,
     F32,
-    EMPTY_TUPLE,
+    Empty_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -49,7 +49,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instanc
     2,
     F32,
     F32,
-    EMPTY_TUPLE,
+    Empty_Tuple,
     F32,
     PassThrough,
     PassThrough,
...

@@ -61,7 +61,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
     2,
     F32,
     F32,
-    EMPTY_TUPLE,
+    Empty_Tuple,
     F32,
     PassThrough,
     PassThrough,
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
new file (mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC, KXC, NWK, BF16, BF16, BF16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC, KXC, NWK, int8_t, int8_t, int8_t,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, BF16, BF16, BF16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F16, F16, F16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F32, F32, F32,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3, NDHWC, KZYXC, NDHWK, BF16, BF16, BF16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3, NDHWC, KZYXC, NDHWK, F16, F16, F16,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3, NDHWC, KZYXC, NDHWK, F32, F32, F32,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<3, NDHWC, KZYXC, NDHWK, int8_t, int8_t, int8_t,
                                                  PassThrough, PassThrough, PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceConvBwdData<NumDimSpatial,
                                                    InLayout,
                                                    WeiLayout,
                                                    OutLayout,
                                                    InDataType,
                                                    WeiDataType,
                                                    OutDataType,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    ck::tensor_operation::element_wise::PassThrough,
                                                    ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdData<NumDimSpatial,
                                       InLayout,
                                       WeiLayout,
                                       OutLayout,
                                       InDataType,
                                       WeiDataType,
                                       OutDataType,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough,
                                       ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> &&
                     is_same_v<WeiLayout, KXC> && is_same_v<OutLayout, NWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, ck::bhalf_t> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, ck::bhalf_t> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, ck::bhalf_t> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
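
A usage sketch (not part of the commit) showing how a caller could enumerate the registered 2-D backward-data instances for fp16 NHWC/KYXC/NHWK through the factory specialization above; the loop body is an assumption about typical caller behavior:

#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp"

void list_conv2d_bwd_data_f16_instances()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // PassThrough, NHWC, etc. are the aliases defined in the factory header.
    using DeviceOp = DeviceConvBwdData<2, NHWC, KYXC, NHWK, F16, F16, F16,
                                       PassThrough, PassThrough, PassThrough>;

    auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op_ptr : op_ptrs)
    {
        // Each element is a std::unique_ptr<DeviceOp>; a caller would typically
        // build an argument, check op_ptr->IsSupportedArgument(...), and run it.
        (void)op_ptr;
    }
}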
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp
0 → 100644
View file @
a1841d55
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<1, NWC, KXC, NWK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<std::unique_ptr<
        DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceConvBwdWeight<
        NumDimSpatial,
        InLayout,
        WeiLayout,
        OutLayout,
        InDataType,
        WeiDataType,
        OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdWeight<NumDimSpatial,
                                         InLayout,
                                         WeiLayout,
                                         OutLayout,
                                         InDataType,
                                         WeiDataType,
                                         OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> &&
                     is_same_v<WeiLayout, KXC> && is_same_v<OutLayout, NWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
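Note that for bf16 this factory keeps the weight tensor in fp32 (the bf16_f32_bf16 combination), so a caller must request exactly that mix; an all-bf16 DeviceConvBwdWeight matches no branch and yields an empty vector. A minimal sketch of that behavior, assuming the layout tags from tensor_layout.hpp:

#include <cassert>

#include "ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    namespace tl      = ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // bf16 in/out with fp32 weights: the registered mixed-precision combination.
    using MixedOp = DeviceConvBwdWeight<2, tl::NHWC, tl::KYXC, tl::NHWK,
                                        ck::bhalf_t, float, ck::bhalf_t,
                                        PassThrough, PassThrough, PassThrough>;
    assert(!instance::DeviceOperationInstanceFactory<MixedOp>::GetInstances().empty());

    // All-bf16 matches no `if constexpr` branch above, so the factory returns nothing.
    using AllBf16Op = DeviceConvBwdWeight<2, tl::NHWC, tl::KYXC, tl::NHWK,
                                          ck::bhalf_t, ck::bhalf_t, ck::bhalf_t,
                                          PassThrough, PassThrough, PassThrough>;
    assert(instance::DeviceOperationInstanceFactory<AllBf16Op>::GetInstances().empty());
}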
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv2d forward
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceConvFwd<2, NHWC, KYXC, NHWK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<
        DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<std::unique_ptr<
        DeviceConvFwd<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
        instances);
template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceConvFwd<
        NumDimSpatial,
        InLayout,
        WeiLayout,
        OutLayout,
        InDataType,
        WeiDataType,
        OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvFwd<NumDimSpatial,
                                   InLayout,
                                   WeiLayout,
                                   OutLayout,
                                   InDataType,
                                   WeiDataType,
                                   OutDataType,
                                   ck::tensor_operation::element_wise::PassThrough,
                                   ck::tensor_operation::element_wise::PassThrough,
                                   ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                     is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
                add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, ck::bhalf_t> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
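The f16 branch is the only one that concatenates two instance families (the plain xdl kernels and the c_shuffle variants) into one candidate list. A short sketch of querying it, under the same assumptions as the earlier examples:

#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/convolution_forward.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    namespace tl      = ck::tensor_layout::convolution;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using F16Op = DeviceConvFwd<2, tl::NHWC, tl::KYXC, tl::NHWK,
                                ck::half_t, ck::half_t, ck::half_t,
                                PassThrough, PassThrough, PassThrough>;

    // op_ptrs holds both the add_device_conv2d_fwd_xdl_* and the
    // add_device_conv2d_fwd_xdl_c_shuffle_* f16 instances.
    auto op_ptrs = instance::DeviceOperationInstanceFactory<F16Op>::GetInstances();
    std::cout << "f16 NHWC forward conv candidates: " << op_ptrs.size() << '\n';
}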
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
@@ -19,49 +19,53 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Row,
+                                                    Row_Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_F16_TUPLE,
+                                                    F16_F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     AddAddFastGelu>>>&);

-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Col,
+                                                    Row_Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_F16_TUPLE,
+                                                    F16_F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     AddAddFastGelu>>>&);

-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Row,
+                                                    Row_Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_F16_TUPLE,
+                                                    F16_F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     AddAddFastGelu>>>&);

-void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Col,
+                                                    Row_Row_Tuple,
                                                     Row,
                                                     F16,
                                                     F16,
-                                                    F16_F16_TUPLE,
+                                                    F16_F16_Tuple,
                                                     F16,
                                                     PassThrough,
                                                     PassThrough,
                                                     AddAddFastGelu>>>&);

@@ -70,7 +74,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
 // GEMM + Add + Add + FastGelu
 template <typename ALayout,
           typename BLayout,
-          typename DELayout,
+          typename D0Layout,
+          typename D1Layout,
+          typename ELayout,
           typename ADataType,
           typename BDataType,
           typename D0DataType,

@@ -79,7 +85,8 @@ template <typename ALayout,
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
                                                       BLayout,
-                                                      DELayout,
+                                                      ck::Tuple<D0Layout, D1Layout>,
+                                                      ELayout,
                                                       ADataType,
                                                       BDataType,
                                                       ck::Tuple<D0DataType, D1DataType>,

@@ -90,7 +97,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 {
     using DeviceOp = DeviceGemmMultipleD<ALayout,
                                          BLayout,
-                                         DELayout,
+                                         ck::Tuple<D0Layout, D1Layout>,
+                                         ELayout,
                                          ADataType,
                                          BDataType,
                                          ck::Tuple<D0DataType, D1DataType>,

@@ -108,27 +116,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                      is_same_v<EDataType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-                         is_same_v<DELayout, Row>)
+                         is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+                         is_same_v<ELayout, Row>)
             {
-                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+                              is_same_v<ELayout, Row>)
             {
-                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+                              is_same_v<ELayout, Row>)
             {
-                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-                              is_same_v<DELayout, Row>)
+                              is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
+                              is_same_v<ELayout, Row>)
             {
-                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
                     op_ptrs);
             }
         }
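After this rename the factory is keyed on explicit D0/D1 layouts rather than a shared DELayout. A hypothetical query for the all-row-major f16 GEMM + Add + Add + FastGelu instances, assuming the Row/F16 aliases used in this header and the argument order shown in the diff above:

#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"

int main()
{
    using ck::tensor_operation::device::DeviceGemmMultipleD;
    using ck::tensor_operation::element_wise::AddAddFastGelu;
    using ck::tensor_operation::element_wise::PassThrough;
    using Row = ck::tensor_layout::gemm::RowMajor;
    using F16 = ck::half_t;

    // E = FastGelu(A * B + D0 + D1); every tensor row-major, everything f16.
    using DeviceOp = DeviceGemmMultipleD<Row, Row, ck::Tuple<Row, Row>, Row,
                                         F16, F16, ck::Tuple<F16, F16>, F16,
                                         PassThrough, PassThrough, AddAddFastGelu>;

    // Resolves to the renamed
    // add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances.
    auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    return op_ptrs.empty() ? 1 : 0;
}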