Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
074f7410
Commit
074f7410
authored
Sep 09, 2021
by
ltqin
Browse files
nhwc atomic is able to run
parent
f907ba09
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
55 deletions
+64
-55
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
...ght_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
+50
-41
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+14
-14
No files found.
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
View file @
074f7410
...
...
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r
3
.hpp"
#include "driver_gemm_xdlops_v2r
4
.hpp"
template
<
typename
TInWei
,
typename
TAcc
,
...
...
@@ -91,20 +91,21 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
1
,
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
,
32
,
2
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmM
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
2
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
32
,
2
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
,
4
,
2
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
32
,
2
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
2
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
KBatch
=
32
;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
...
...
@@ -150,21 +151,25 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const
auto
wei_gemmm_gemmn_grid_desc
=
descs
[
I2
];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
// 1+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
// 1-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmM
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 2+: GemmK1
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: GemmK0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: GemmN
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 2-: GemmK1
constexpr
auto
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
...
...
@@ -185,19 +190,22 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
1
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
1
,
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
std
::
function
<
void
()
>
clear_weight
=
[
&
wei_k_y_x_c_device_buf
,
&
wei_k_y_x_c
]()
{
wei_k_y_x_c_device_buf
.
ToDevice
(
wei_k_y_x_c
.
mData
.
data
());
};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r
3
<
float
ave_time
=
driver_gemm_xdlops_v2r
4
<
BlockSize
,
TInWei
,
TAcc
,
TOut
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
decltype
(
in_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
wei_gemmm_gemmn_grid_desc
),
...
...
@@ -211,17 +219,17 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
Sequence
<
0
,
1
,
3
,
2
>
,
Sequence
<
0
,
1
,
3
,
2
>
,
2
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmK1
,
false
,
// don't move back src coordinate after threadwise copy
...
...
@@ -233,19 +241,20 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
false
,
// CAccessOrderMRepeatNRepeat
KBatch
>
(
static_cast
<
TInWei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
,
&
clear_weight
);
{
const
auto
N
=
out_n_ho_wo_k_lengths
[
I0
];
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
074f7410
...
...
@@ -350,20 +350,20 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk
<
in_data_t
,
acc
_data_t
,
out
_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
nrepeat
);
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk
<
in
_data_t
,
acc
_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I6
],
in
,
wei_device
,
out
,
nrepeat
);
}
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment