Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
47b26f0f
Commit
47b26f0f
authored
Sep 23, 2021
by
ltqin
Browse files
v4r4r4 fp16
parent
3c150a8f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
10 deletions
+12
-10
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
...ght_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
+10
-9
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+2
-1
No files found.
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
View file @
47b26f0f
...
...
@@ -4,7 +4,8 @@
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template
<
typename
TInWei
,
template
<
typename
TIn
,
typename
TWei
,
typename
TAcc
,
typename
TOut
,
typename
InLengths
,
...
...
@@ -23,8 +24,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TIn
Wei
>&
in_n_hi_wi_c
,
Tensor
<
T
In
Wei
>&
wei_k_y_x_c
,
const
Tensor
<
TIn
>&
in_n_hi_wi_c
,
Tensor
<
TWei
>&
wei_k_y_x_c
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
GemmKBatchType
GemmKBatch
,
ck
::
index_t
nrepeat
)
...
...
@@ -38,8 +39,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TIn
Wei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
T
In
Wei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TIn
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_ho_wo_k_device_buf
(
sizeof
(
TOut
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
());
in_n_hi_wi_c_device_buf
.
ToDevice
(
in_n_hi_wi_c
.
mData
.
data
());
...
...
@@ -176,9 +177,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
{
float
ave_time
=
driver_gemm_xdlops_v2r4
<
BlockSize
,
TIn
Wei
,
TIn
,
TAcc
,
T
Out
,
T
Wei
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
decltype
(
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
),
...
...
@@ -216,9 +217,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
decltype
(
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TIn
Wei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TIn
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
In
Wei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
47b26f0f
...
...
@@ -22,7 +22,7 @@
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC
1
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
enum
ConvBackwardWeightAlgo
...
...
@@ -360,6 +360,7 @@ int main(int argc, char* argv[])
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk
<
in_data_t
,
wei_data_t
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment