gaoqiong / composable_kernel

Commit 33db3d8f, authored Sep 23, 2021 by ltqin
Parent: 0c8aa120

    V4R4R2XDLATOMICNCHW fp16

Showing 4 changed files with 204 additions and 65 deletions:
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp (+132, -0, new file)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp (+56, -52)
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp (+11, -10)
host/driver_offline/src/conv_wrw_driver_offline.cpp (+5, -3)
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp (new file, mode 100644):
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    const auto GemmM      = K;
    const auto GemmN      = C * Y * X;
    const auto GemmKTotal = N * Ho * Wo;
    const auto GemmK      = GemmKTotal / GemmKBatch;
    const auto GemmK0     = GemmK / GemmK1;

    // A: output tensor
    const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_gemmktotal_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // B: input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(
            make_pass_through_transform(N),
            make_pass_through_transform(C),
            make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
            make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_c_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        in_gemmktotal_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // C: weight tensor
    const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                      in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                      wei_gemmm_gemmn_grid_desc);
}

} // namespace ck
#endif
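The comment block at the top of this new header summarizes the GEMM view of the weight-gradient convolution (GemmM = K, GemmN = C * Y * X, GemmK = N * Ho * Wo). For intuition, here is a small standalone sketch, not part of the commit, of the split that the unmerge transform applies to the reduction dimension; all sizes (N, Ho, Wo, GemmKBatch, GemmK1) are illustrative.

// Standalone illustration (plain C++, hypothetical sizes) of the
// [GemmKBatch, GemmK0, GemmK1] factorization performed by the transform above.
#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative NCHW convolution sizes.
    const int N = 256, Ho = 28, Wo = 28;

    // Reduction length of the weight-gradient GEMM: GemmKTotal = N * Ho * Wo.
    const int GemmKTotal = N * Ho * Wo;

    const int GemmKBatch = 4; // slices whose partial sums are combined by atomic add
    const int GemmK1     = 8; // innermost K length (GemmK1Value)

    // The transform divides with no remainder handling, so the caller must
    // pick GemmKBatch and GemmK1 that evenly divide GemmKTotal.
    assert(GemmKTotal % GemmKBatch == 0);
    const int GemmK = GemmKTotal / GemmKBatch;
    assert(GemmK % GemmK1 == 0);
    const int GemmK0 = GemmK / GemmK1;

    std::printf("GemmKTotal = %d -> [GemmKBatch = %d, GemmK0 = %d, GemmK1 = %d]\n",
                GemmKTotal, GemmKBatch, GemmK0, GemmK1);
    return 0;
}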
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp:
 #include <unistd.h>
 #include "device.hpp"
 #include "host_tensor.hpp"
-#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
+#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
 #include "driver_gemm_xdlops_v2r4.hpp"

-template <typename TInWei,
+template <typename TIn,
+          typename TWei,
           typename TAcc,
           typename TOut,
           typename InLengths,
...
@@ -13,7 +14,8 @@ template <typename TInWei,
           typename ConvStrides,
           typename ConvDilations,
           typename InLeftPads,
-          typename InRightPads>
+          typename InRightPads,
+          typename GemmKBatchType>
 void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
     const InLengths& in_n_c_hi_wi_lengths,
     const WeiLengths& wei_k_c_y_x_lengths,
...
@@ -22,9 +24,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    const Tensor<TInWei>& in_n_c_hi_wi,
-    Tensor<TInWei>& wei_k_c_y_x,
+    const Tensor<TIn>& in_n_c_hi_wi,
+    Tensor<TWei>& wei_k_c_y_x,
     const Tensor<TOut>& out_n_k_ho_wo,
+    GemmKBatchType GemmKBatch,
     ck::index_t nrepeat)
 {
     using namespace ck;
...
@@ -35,8 +38,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};

-    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
-    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
     DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

     in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
...
@@ -79,15 +82,17 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     constexpr index_t KBatch = 64;
 #endif

-    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
-        wei_k_c_y_x_desc,
-        in_n_c_hi_wi_desc,
-        out_n_k_ho_wo_desc,
-        conv_strides,
-        conv_dilations,
-        in_left_pads,
-        in_right_pads,
-        Number<GemmK1>{});
+    const auto descs =
+        transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
+            wei_k_c_y_x_desc,
+            in_n_c_hi_wi_desc,
+            out_n_k_ho_wo_desc,
+            conv_strides,
+            conv_dilations,
+            in_left_pads,
+            in_right_pads,
+            Number<GemmK1>{},
+            GemmKBatch);

     const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
     const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
...
@@ -95,24 +100,24 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0, 0>{},   // 0+: GemmB
-                              Sequence<0, 0, 1, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
-                              Sequence<0, 0, 1, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
-                   make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0, 0>{},   // 0-: GemB
-                              Sequence<0, 0, 2, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
-                              Sequence<0, 0, 2, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
-    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 0+: GemmB
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmN
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 0-: GemmB
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmN
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
+    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmB
+                              Sequence<0, 0, 1, 0, 0>{},   // 1+: GemmK0
+                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmM
+                              Sequence<0, 0, 1, 0, 0>{}),  // 3+: GemmK1
+                   make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemB
+                              Sequence<0, 0, 2, 0, 0>{},   // 1-: GemmK0
+                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmM
+                              Sequence<0, 0, 2, 0, 0>{})); // 3-: GemmK1
+    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmB
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 1+: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 2+: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 3+: GemmK1
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmB
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 1-: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 2-: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 3-: GemmK1

     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
...
@@ -133,10 +138,10 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
                    Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

     constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 1, 0, 0, 0, 0, 0>{};
+        Sequence<0, 0, 1, 0, 0>{};
     constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{};
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

     for(index_t i = 0; i < 5; ++i)
     {
...
@@ -146,9 +151,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
         float ave_time =
             driver_gemm_xdlops_v2r4<BlockSize,
-                                    TInWei,
+                                    TIn,
                                     TAcc,
-                                    TOut,
+                                    TWei,
                                     InMemoryDataOperationEnum_t::AtomicAdd,
                                     decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                                     decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
...
@@ -185,20 +190,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
                                     decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                                     decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                                     decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
-                                    false,
-                                    KBatch>(
-            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-            static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-            static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-            out_gemmk0_gemmm_gemmk1_grid_desc,
-            in_gemmk0_gemmn_gemmk1_grid_desc,
-            wei_gemmm_gemmn_grid_desc,
-            out_gemmk0_gemmm_gemmk1_grid_step_hacks,
-            in_gemmk0_gemmn_gemmk1_grid_step_hacks,
-            wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-            out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-            in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
-            nrepeat,
-            &clear_weight);
+                                    false>(
+            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+            static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+            static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+            out_gemmk0_gemmm_gemmk1_grid_desc,
+            in_gemmk0_gemmn_gemmk1_grid_desc,
+            wei_gemmm_gemmn_grid_desc,
+            out_gemmk0_gemmm_gemmk1_grid_step_hacks,
+            in_gemmk0_gemmn_gemmk1_grid_step_hacks,
+            wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+            out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+            in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+            nrepeat,
+            &clear_weight);

         float perf = static_cast<float>(calculate_convolution_flops(
                          in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
...
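The device function now drives the GEMM with InMemoryDataOperationEnum_t::AtomicAdd together with the clear_weight callback. The host-side sketch below (hypothetical names and loop structure, not the kernel itself) shows why the clear step is needed: each of the GemmKBatch slices computes only a partial weight gradient and adds it into a shared buffer, so that buffer must start at zero.

// Host-side analogue (hypothetical) of split-K accumulation with atomic adds:
// each k-slice produces a partial dW for its own chunk of GemmK and adds it
// into one shared GemmM x GemmN buffer.
#include <algorithm>
#include <vector>

void accumulate_wrw_partial_sums(std::vector<float>& dW,
                                 int GemmKBatch, int GemmM, int GemmN)
{
    // Analogue of clear_weight: slices only ever *add* their contribution,
    // so the destination must be zeroed before the first slice runs.
    std::fill(dW.begin(), dW.end(), 0.0f);

    for(int kb = 0; kb < GemmKBatch; ++kb) // on the GPU these slices run concurrently
    {
        for(int m = 0; m < GemmM; ++m)
        {
            for(int n = 0; n < GemmN; ++n)
            {
                // Placeholder for the GEMM over this slice's GemmK0 * GemmK1 range.
                const float partial = 0.0f;
                dW[m * GemmN + n] += partial; // AtomicAdd in the real kernel
            }
        }
    }
}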
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp:

@@ -109,16 +109,17 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #endif

-    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
-        in_n_hi_wi_c_desc,
-        wei_k_y_x_c_desc,
-        out_n_ho_wo_k_desc,
-        conv_strides,
-        conv_dilations,
-        in_left_pads,
-        in_right_pads,
-        Number<GemmK1>{},
-        GemmKBatch);
+    const auto descs =
+        transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
+            in_n_hi_wi_c_desc,
+            wei_k_y_x_c_desc,
+            out_n_ho_wo_k_desc,
+            conv_strides,
+            conv_dilations,
+            in_left_pads,
+            in_right_pads,
+            Number<GemmK1>{},
+            GemmKBatch);

     const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
     const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
...
host/driver_offline/src/conv_wrw_driver_offline.cpp:

@@ -19,10 +19,10 @@
 #include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"

 #define USE_DYNAMIC_MODE 1
-#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
-#define USE_CONV_WRW_V4R4R4_XDL_NHWC 1
+#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
+#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
 #define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
-#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
+#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
 #define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1

 enum ConvBackwardWeightAlgo
...
@@ -335,6 +335,7 @@ int main(int argc, char* argv[])
             device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw<
                 in_data_t,
+                wei_data_t,
                 acc_data_t,
                 out_data_t>(tmp[I0],
                             tmp[I1],
...
@@ -346,6 +347,7 @@ int main(int argc, char* argv[])
                 in,
                 wei_device,
                 out,
+                k_batch,
                 nrepeat);
         }
 #endif
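Taken together, the driver change closes the loop: a k_batch value chosen in conv_wrw_driver_offline.cpp is forwarded to the device function as GemmKBatch and from there into the descriptor transform. The self-contained sketch below (plain C++ with hypothetical stand-in functions, not the repository's real signatures) shows that threading of the parameter.

// Self-contained sketch (hypothetical) of how the new k_batch argument
// threads from the driver into the descriptor transform.
#include <cstdio>

using index_t = int; // stand-in for ck::index_t

// Stand-in for transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_..._pad
void make_descriptors(index_t GemmKBatch)
{
    std::printf("building [GemmKBatch = %d, GemmK0, GemmM/GemmN, GemmK1] descriptors\n",
                GemmKBatch);
}

// Stand-in for device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_...
void device_conv_wrw(index_t GemmKBatch, index_t nrepeat)
{
    make_descriptors(GemmKBatch); // forwarded unchanged, as in the commit
    std::printf("timing over %d repetitions\n", nrepeat);
}

int main()
{
    const index_t k_batch = 4; // chosen by the driver; illustrative value
    device_conv_wrw(k_batch, /*nrepeat=*/5);
    return 0;
}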