Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5a79ff1e
Commit
5a79ff1e
authored
Dec 02, 2021
by
Chao Liu
Browse files
add conv+bias+relu+add, but has register spill issue
parent
25343b48
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
784 deletions
+214
-784
example/3_conv_xdl/conv_xdl.cpp
example/3_conv_xdl/conv_xdl.cpp
+5
-5
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
+114
-67
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add.hpp
...u_add/include/device_conv_fwd_xdl_bias_activation_add.hpp
+3
-3
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp
...evice_conv_fwd_xdl_bias_activation_add_nhwc_kyxc_nhwk.hpp
+92
-67
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk_two_extra_source_reduce.hpp
...e_conv_fwd_xdl_nhwc_kyxc_nhwk_two_extra_source_reduce.hpp
+0
-642
No files found.
example/3_conv_xdl/conv_xdl.cpp
View file @
5a79ff1e
...
@@ -51,11 +51,11 @@ using OutElementOp = Relu;
...
@@ -51,11 +51,11 @@ using OutElementOp = Relu;
using
DeviceConvFwdInstance
=
using
DeviceConvFwdInstance
=
// clang-format off
// clang-format off
//############################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//############################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//############################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//############################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//############################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//############################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceConvFwdXdl
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InLayout
,
WeiLayout
,
OutLayout
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
ck
::
tensor_operation
::
device
::
DeviceConvFwdXdl
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InLayout
,
WeiLayout
,
OutLayout
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
// clang-format on
// clang-format on
template
<
typename
TIn
,
template
<
typename
TIn
,
...
...
example/4_conv_xdl_bias_relu_add/conv_xdl_bias_relu_add.cpp
View file @
5a79ff1e
...
@@ -11,8 +11,8 @@
...
@@ -11,8 +11,8 @@
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
two_extra_source_reduce
.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
bias_activation_add
.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
two_extra_source_reduce
_nhwc_kyxc_nhwk.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
bias_activation_add
_nhwc_kyxc_nhwk.hpp"
struct
PassThrough
struct
PassThrough
{
{
...
@@ -23,13 +23,17 @@ struct PassThrough
...
@@ -23,13 +23,17 @@ struct PassThrough
}
}
};
};
struct
Relu
struct
Bias
Relu
Add
{
{
template
<
typename
T
>
template
<
typename
T
1
,
typename
T2
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
__host__
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
{
T
tmp
=
0.1
*
v
;
float
a
=
v0
+
v1
;
return
tmp
>
0
?
tmp
:
0
;
float
b
=
float
(
0.1
)
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
}
};
};
...
@@ -47,15 +51,15 @@ using OutLayout = ck::tensor_layout::convolution::NHWK;
...
@@ -47,15 +51,15 @@ using OutLayout = ck::tensor_layout::convolution::NHWK;
using
InElementOp
=
PassThrough
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
Relu
;
using
OutElementOp
=
Bias
Relu
Add
;
using
DeviceConvFwdInstance
=
using
DeviceConvFwdInstance
=
// clang-format off
// clang-format off
//
############################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//####################
############################################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//
############################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//####################
############################################| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//
############################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//####################
############################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//
############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//####################
############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck
::
tensor_operation
::
device
::
DeviceConvFwdXdl_
two_extra_source_reduce
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InLayout
,
WeiLayout
,
OutLayout
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
ck
::
tensor_operation
::
device
::
DeviceConvFwdXdl_
bias_activation_add
<
2
,
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InLayout
,
WeiLayout
,
OutLayout
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
7
,
1
,
true
,
true
>
;
// clang-format on
// clang-format on
template
<
typename
TIn
,
template
<
typename
TIn
,
...
@@ -64,9 +68,11 @@ template <typename TIn,
...
@@ -64,9 +68,11 @@ template <typename TIn,
typename
InElementOp
,
typename
InElementOp
,
typename
WeiElementOp
,
typename
WeiElementOp
,
typename
OutElementOp
>
typename
OutElementOp
>
void
host_verify
(
const
Tensor
<
TIn
>&
in
,
void
host_reference_calculation
(
const
Tensor
<
TIn
>&
in_n_c_hi_wi
,
const
Tensor
<
TWei
>&
wei
,
const
Tensor
<
TWei
>&
wei_k_c_y_x
,
Tensor
<
TOut
>&
out
,
Tensor
<
TOut
>&
out_n_k_ho_wo
,
const
Tensor
<
TOut
>&
bias_k
,
const
Tensor
<
TOut
>&
resi_n_k_ho_wo
,
const
std
::
vector
<
ck
::
index_t
>&
conv_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
conv_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
in_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
in_left_pads
,
...
@@ -77,41 +83,43 @@ void host_verify(const Tensor<TIn>& in,
...
@@ -77,41 +83,43 @@ void host_verify(const Tensor<TIn>& in,
{
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
double
v
=
0
;
double
v
=
0
;
for
(
int
c
=
0
;
c
<
wei
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
int
c
=
0
;
c
<
wei
_k_c_y_x
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
int
y
=
0
;
y
<
wei
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
int
y
=
0
;
y
<
wei
_k_c_y_x
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
*
conv_strides
[
0
]
+
y
*
conv_dilations
[
0
]
-
in_left_pads
[
0
];
int
hi
=
ho
*
conv_strides
[
0
]
+
y
*
conv_dilations
[
0
]
-
in_left_pads
[
0
];
for
(
int
x
=
0
;
x
<
wei
.
mDesc
.
GetLengths
()[
3
];
++
x
)
for
(
int
x
=
0
;
x
<
wei
_k_c_y_x
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
*
conv_strides
[
1
]
+
x
*
conv_dilations
[
1
]
-
in_left_pads
[
1
];
int
wi
=
wo
*
conv_strides
[
1
]
+
x
*
conv_dilations
[
1
]
-
in_left_pads
[
1
];
if
(
hi
>=
0
&&
hi
<
in
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
in
_n_c_hi_wi
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in
.
mDesc
.
GetLengths
()[
3
])
wi
<
in
_n_c_hi_wi
.
mDesc
.
GetLengths
()[
3
])
{
{
v
+=
in_element_op
(
static_cast
<
const
double
>
(
in
(
n
,
c
,
hi
,
wi
)))
*
v
+=
in_element_op
(
static_cast
<
const
double
>
(
in
_n_c_hi_wi
(
n
,
c
,
hi
,
wi
)))
*
wei_element_op
(
static_cast
<
const
double
>
(
wei
(
k
,
c
,
y
,
x
)));
wei_element_op
(
static_cast
<
const
double
>
(
wei
_k_c_y_x
(
k
,
c
,
y
,
x
)));
}
}
}
}
}
}
}
}
out
(
n
,
k
,
ho
,
wo
)
=
out_element_op
(
v
);
out_n_k_ho_wo
(
n
,
k
,
ho
,
wo
)
=
out_element_op
(
v
,
bias_k
(
k
),
resi_n_k_ho_wo
(
n
,
k
,
ho
,
wo
));
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
f_nchw
,
out
.
mDesc
.
GetLengths
()[
0
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
0
],
out
.
mDesc
.
GetLengths
()[
1
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
1
],
out
.
mDesc
.
GetLengths
()[
2
],
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
2
],
out
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
out_n_k_ho_wo
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!=
4
)
//
if(argc != 4)
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
printf
(
"arg3: run kernel # of times (>1)
\n
"
);
exit
(
0
);
//
exit(0);
}
}
const
bool
do_verification
=
std
::
stoi
(
argv
[
1
]);
const
bool
do_verification
=
std
::
stoi
(
argv
[
1
]);
...
@@ -119,6 +127,7 @@ int main(int argc, char* argv[])
...
@@ -119,6 +127,7 @@ int main(int argc, char* argv[])
const
int
nrepeat
=
std
::
stoi
(
argv
[
3
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
3
]);
// Conv shape
// Conv shape
#if 0
const ck::index_t N = 128;
const ck::index_t N = 128;
const ck::index_t K = 256;
const ck::index_t K = 256;
const ck::index_t C = 192;
const ck::index_t C = 192;
...
@@ -134,6 +143,23 @@ int main(int argc, char* argv[])
...
@@ -134,6 +143,23 @@ int main(int argc, char* argv[])
const ck::index_t in_left_pad_w = 1;
const ck::index_t in_left_pad_w = 1;
const ck::index_t in_right_pad_h = 1;
const ck::index_t in_right_pad_h = 1;
const ck::index_t in_right_pad_w = 1;
const ck::index_t in_right_pad_w = 1;
#else
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
4
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
5
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
6
]);
const
ck
::
index_t
Y
=
std
::
stoi
(
argv
[
7
]);
const
ck
::
index_t
X
=
std
::
stoi
(
argv
[
8
]);
const
ck
::
index_t
Hi
=
std
::
stoi
(
argv
[
9
]);
const
ck
::
index_t
Wi
=
std
::
stoi
(
argv
[
10
]);
const
ck
::
index_t
conv_stride_h
=
std
::
stoi
(
argv
[
11
]);
const
ck
::
index_t
conv_stride_w
=
std
::
stoi
(
argv
[
12
]);
const
ck
::
index_t
conv_dilation_h
=
std
::
stoi
(
argv
[
13
]);
const
ck
::
index_t
conv_dilation_w
=
std
::
stoi
(
argv
[
14
]);
const
ck
::
index_t
in_left_pad_h
=
std
::
stoi
(
argv
[
15
]);
const
ck
::
index_t
in_left_pad_w
=
std
::
stoi
(
argv
[
16
]);
const
ck
::
index_t
in_right_pad_h
=
std
::
stoi
(
argv
[
17
]);
const
ck
::
index_t
in_right_pad_w
=
std
::
stoi
(
argv
[
18
]);
#endif
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
YEff
=
(
Y
-
1
)
*
conv_dilation_h
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
const
ck
::
index_t
XEff
=
(
X
-
1
)
*
conv_dilation_w
+
1
;
...
@@ -178,9 +204,18 @@ int main(int argc, char* argv[])
...
@@ -178,9 +204,18 @@ int main(int argc, char* argv[])
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
Tensor
<
OutDataType
>
out_n_k_ho_wo_device_result
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
// bias: assume contiguous 1d vector
Tensor
<
OutDataType
>
bias_k
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
K
)})));
// residual: assume same layout as output tensor
Tensor
<
OutDataType
>
resi_n_k_ho_wo
(
f_host_tensor_descriptor
(
N
,
K
,
Ho
,
Wo
,
OutLayout
{}));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei_k_c_y_x: "
<<
wei_k_c_y_x
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_k_ho_wo: "
<<
out_n_k_ho_wo_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias_k: "
<<
bias_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"resi_n_k_ho_wo: "
<<
resi_n_k_ho_wo
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -188,26 +223,36 @@ int main(int argc, char* argv[])
...
@@ -188,26 +223,36 @@ int main(int argc, char* argv[])
case
1
:
case
1
:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
resi_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
wei_k_c_y_x
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
bias_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
resi_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
DeviceMem
resi_device_buf
(
sizeof
(
OutDataType
)
*
resi_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias_k
.
mData
.
data
());
resi_device_buf
.
ToDevice
(
resi_n_k_ho_wo
.
mData
.
data
());
// do GEMM
auto
conv
=
DeviceConvFwdInstance
{};
auto
conv
=
DeviceConvFwdInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
conv
.
MakeArgument
(
static_cast
<
const
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
bias_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
OutDataType
*>
(
resi_device_buf
.
GetDeviceBuffer
()),
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -246,9 +291,11 @@ int main(int argc, char* argv[])
...
@@ -246,9 +291,11 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
host_
verify
(
in_n_c_hi_wi
,
host_
reference_calculation
(
in_n_c_hi_wi
,
wei_k_c_y_x
,
wei_k_c_y_x
,
out_n_k_ho_wo_host_result
,
out_n_k_ho_wo_host_result
,
bias_k
,
resi_n_k_ho_wo
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
two_extra_source_reduce
.hpp
→
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
bias_activation_add
.hpp
View file @
5a79ff1e
#ifndef DEVICE_CONV_FWD_XDL_
TWO_EXTRA_SOURCE_REDUCE
_HPP
#ifndef DEVICE_CONV_FWD_XDL_
BIAS_ACTIVATION_ADD
_HPP
#define DEVICE_CONV_FWD_XDL_
TWO_EXTRA_SOURCE_REDUCE
_HPP
#define DEVICE_CONV_FWD_XDL_
BIAS_ACTIVATION_ADD
_HPP
#include <iostream>
#include <iostream>
#include "device.hpp"
#include "device.hpp"
...
@@ -53,7 +53,7 @@ template <ck::index_t NDimSpatial,
...
@@ -53,7 +53,7 @@ template <ck::index_t NDimSpatial,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
bool
ABlockLdsAddExtraM
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
bool
BBlockLdsAddExtraN
>
struct
DeviceConvFwdXdl_
two_extra_source_reduce
;
struct
DeviceConvFwdXdl_
bias_activation_add
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
two_extra_source_reduce
_nhwc_kyxc_nhwk.hpp
→
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
bias_activation_add
_nhwc_kyxc_nhwk.hpp
View file @
5a79ff1e
#ifndef DEVICE_CONV_FWD_XDL_
TWO_EXTRA_SOURCE_REDUCE
_NHWC_KYXC_NHWK_HPP
#ifndef DEVICE_CONV_FWD_XDL_
BIAS_ACTIVATION_ADD
_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV_FWD_XDL_
TWO_EXTRA_SOURCE_REDUCE
_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV_FWD_XDL_
BIAS_ACTIVATION_ADD
_NHWC_KYXC_NHWK_HPP
#include <iostream>
#include <iostream>
#include "device.hpp"
#include "device.hpp"
...
@@ -9,8 +9,8 @@
...
@@ -9,8 +9,8 @@
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r
3
.hpp"
#include "gridwise_gemm_xdlops_v2r
5
.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
two_extra_source_reduce
.hpp"
#include "example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_
bias_activation_add
.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -51,7 +51,7 @@ template <typename InDataType,
...
@@ -51,7 +51,7 @@ template <typename InDataType,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
bool
ABlockLdsAddExtraM
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
>
bool
BBlockLdsAddExtraN
>
struct
DeviceConvFwdXdl_
two_extra_source_reduce
<
struct
DeviceConvFwdXdl_
bias_activation_add
<
2
,
// ck::index_t NDimSpatial,
2
,
// ck::index_t NDimSpatial,
InDataType
,
// typename InDataType,
InDataType
,
// typename InDataType,
WeiDataType
,
// typename WeiDataType,
WeiDataType
,
// typename WeiDataType,
...
@@ -108,6 +108,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -108,6 +108,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
GemmK1Number
=
K1Number
;
static
constexpr
auto
GemmK1Number
=
K1Number
;
...
@@ -153,6 +154,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -153,6 +154,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
const
auto
GemmMPad
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
)
-
GemmMRaw
;
const
auto
GemmMPad
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
)
-
GemmMRaw
;
const
auto
GemmM
=
GemmMRaw
+
GemmMPad
;
assert
(
GemmK
%
GemmK1Number
==
0
);
assert
(
GemmK
%
GemmK1Number
==
0
);
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
const
index_t
GemmK0
=
GemmK
/
GemmK1Number
;
...
@@ -236,9 +239,18 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -236,9 +239,18 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// C0: bias tensor: assume a contiguous vector
const
auto
bias_grid_desc_gemmm_gemmn
=
make_naive_tensor_descriptor
(
make_tuple
(
GemmM
,
GemmN
),
make_tuple
(
0
,
1
));
// C1: residual tensor: assume same layout as output tensor
const
auto
resi_grid_desc_gemmm_gemmn
=
out_gemmm_gemmn_grid_desc
;
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmm_gemmn_grid_desc
);
out_gemmm_gemmn_grid_desc
,
bias_grid_desc_gemmm_gemmn
,
resi_grid_desc_gemmm_gemmn
);
}
}
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
using
ABCGridDescs
=
decltype
(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
...
@@ -247,6 +259,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -247,6 +259,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
C0GridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I3
])
>
;
using
C1GridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I4
])
>
;
// TODO remove these hacks
// TODO remove these hacks
static
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
static
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
...
@@ -289,7 +303,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -289,7 +303,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
static
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
static
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r
3
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r
5
<
BlockSize
,
BlockSize
,
ABDataType
,
// TODO: distinguish A/B datatype
ABDataType
,
// TODO: distinguish A/B datatype
AccDataType
,
AccDataType
,
...
@@ -298,6 +312,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -298,6 +312,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
AGridDesc_K0_M_K1
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
...
@@ -340,6 +356,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -340,6 +356,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
using
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C1GridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
// Argument
// Argument
...
@@ -348,6 +370,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -348,6 +370,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -366,10 +390,16 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -366,10 +390,16 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
:
p_a_grid_
{
p_in_grid
},
:
p_a_grid_
{
p_in_grid
},
p_b_grid_
{
p_wei_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_out_grid
},
p_c_grid_
{
p_out_grid
},
p_c0_grid_
{
p_bias_grid
},
p_c1_grid_
{
p_resi_grid
},
a_grid_desc_k0_m_k1_
{},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c1_grid_desc_m_n_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
...
@@ -377,7 +407,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -377,7 +407,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
wei_element_op_
{
wei_element_op
},
wei_element_op_
{
wei_element_op
},
out_element_op_
{
out_element_op
}
out_element_op_
{
out_element_op
}
{
{
const
auto
descs
=
DeviceConvFwdXdl_
two_extra_source_reduce
::
const
auto
descs
=
DeviceConvFwdXdl_
bias_activation_add
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
K
,
K
,
C
,
C
,
...
@@ -392,6 +422,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -392,6 +422,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
a_grid_desc_k0_m_k1_
=
descs
[
I0
];
a_grid_desc_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
c_grid_desc_m_n_
=
descs
[
I2
];
c0_grid_desc_m_n_
=
descs
[
I3
];
c1_grid_desc_m_n_
=
descs
[
I4
];
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
...
@@ -399,6 +431,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -399,6 +431,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c0_grid_desc_m_n_
);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c1_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
}
}
...
@@ -407,10 +445,16 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -407,10 +445,16 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
const
CDataType
*
p_c0_grid_
;
const
CDataType
*
p_c1_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -422,7 +466,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -422,7 +466,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceConvFwdXdl_
two_extra_source_reduce
::
Argument
;
using
Argument
=
DeviceConvFwdXdl_
bias_activation_add
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
...
@@ -437,6 +481,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -437,6 +481,12 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c0_grid_desc_m_n_{ "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c1_grid_desc_m_n_{ "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
...
@@ -446,7 +496,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -446,7 +496,7 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
arg
.
N01_
))
arg
.
N01_
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r
3
has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r
5
has invalid setting"
);
}
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
...
@@ -459,18 +509,22 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -459,18 +509,22 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r
3
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r
5
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
DeviceConvFwdXdl_bias_activation_add
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceConvFwdXdl_
two_extra_source_reduce
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceConvFwdXdl_
bias_activation_add
::
Block2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
...
@@ -481,9 +535,13 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -481,9 +535,13 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -491,18 +549,22 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -491,18 +549,22 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r
3
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r
5
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceConvFwdXdl_two_extra_source_reduce
::
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
DeviceConvFwdXdl_bias_activation_add
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceConvFwdXdl_bias_activation_add
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceConvFwdXdl_
two_extra_source_reduce
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceConvFwdXdl_
bias_activation_add
::
Block2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
...
@@ -513,9 +575,13 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -513,9 +575,13 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -556,6 +622,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -556,6 +622,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_bias_grid
,
const
OutDataType
*
p_resi_grid
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
K
,
ck
::
index_t
C
,
ck
::
index_t
C
,
...
@@ -573,6 +641,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -573,6 +641,8 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
p_bias_grid
,
p_resi_grid
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -591,51 +661,6 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
...
@@ -591,51 +661,6 @@ struct DeviceConvFwdXdl_two_extra_source_reduce<
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
const
void
*
p_wei_grid
,
void
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
OutDataType
*>
(
p_out_grid
),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
};
// namespace device
};
// namespace device
}
// namespace device
}
// namespace device
...
...
example/4_conv_xdl_bias_relu_add/include/device_conv_fwd_xdl_nhwc_kyxc_nhwk_two_extra_source_reduce.hpp
deleted
100644 → 0
View file @
25343b48
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment