Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5cfd01fd
Commit
5cfd01fd
authored
Aug 20, 2021
by
ltqin
Browse files
format
parent
1e9c511c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
40 deletions
+34
-40
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
...ard_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+11
-15
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+23
-25
No files found.
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
View file @
5cfd01fd
...
...
@@ -77,15 +77,15 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
Number
<
GemmK1
>
{});
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
Number
<
GemmK1
>
{});
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
descs
[
I0
];
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
descs
[
I1
];
...
...
@@ -93,13 +93,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
out_gemmk0_gemmm_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
1
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
1
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
constexpr
auto
in_gemmk0_gemmn_gemmk1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
5cfd01fd
...
...
@@ -14,13 +14,12 @@
#include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#define USE_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
enum
ConvBackwardWeightAlgo
{
V4R4R2XDLNCHW
,
V4R4R2XDLNCHW
,
};
int
main
(
int
argc
,
char
*
argv
[])
...
...
@@ -44,12 +43,12 @@ int main(int argc, char* argv[])
exit
(
1
);
}
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
N
=
std
::
stoi
(
argv
[
7
]);
const
index_t
K
=
std
::
stoi
(
argv
[
8
]);
...
...
@@ -81,12 +80,12 @@ int main(int argc, char* argv[])
exit
(
1
);
}
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
ConvTensorLayout
layout
=
static_cast
<
ConvTensorLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
ConvBackwardWeightAlgo
algo
=
static_cast
<
ConvBackwardWeightAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
...
...
@@ -245,7 +244,6 @@ int main(int argc, char* argv[])
in_right_pads_dev
);
};
#if USE_CONV_WRW_V4R4R2_XDL_NCHW
if
(
algo
==
ConvBackwardWeightAlgo
::
V4R4R2XDLNCHW
)
{
...
...
@@ -257,8 +255,8 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nchw
();
device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
<
in_data_t
,
acc_data_t
,
out_data_t
>
(
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I2
],
...
...
@@ -275,14 +273,14 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
host_direct_convolution_backward_weights
(
out
,
in
,
wei_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
layout
);
host_direct_convolution_backward_weights
(
out
,
in
,
wei_host
,
make_tuple
(
conv_stride_h
,
conv_stride_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
layout
);
check_error
(
wei_host
,
wei_device
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment