Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0c8aa120
Commit
0c8aa120
authored
Sep 23, 2021
by
ltqin
Browse files
V4R4R4XDLNHWC fp16
parent
9da60a34
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
172 additions
and
38 deletions
+172
-38
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
...ht_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
+132
-0
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
...rd_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
+26
-26
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
...ght_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
+2
-2
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...ard_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+10
-9
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+2
-1
No files found.
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
0 → 100644
View file @
0c8aa120
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// Expresses a backward-weight convolution (NHWC input, KYXC weight, NHWK output)
// as a GEMM whose reduction dimension is additionally split into GemmKBatch parts.
//
// GEMM operand roles, as constructed below:
//   A: input tensor   -> [GemmKBatch, GemmK0, GemmM, GemmK1]
//   B: output tensor  -> [GemmKBatch, GemmK0, GemmN, GemmK1]
//   C: weight tensor  -> [GemmM, GemmN]
//
// GEMM sizes:
//   GemmM      = Y * X * C
//   GemmN      = K
//   GemmKTotal = N * Ho * Wo
//   GemmK      = GemmKTotal / GemmKBatch
//   GemmK0     = GemmK / GemmK1
//
// Parameters:
//   in_n_hi_wi_c_grid_desc   input descriptor; lengths N, Hi, Wi, C are read from it and
//                            it is the base of the A descriptor.
//   wei_k_y_x_c_grid_desc    weight descriptor; only lengths Y and X are read from it.
//   out_n_ho_wo_k_grid_desc  output descriptor; only lengths Ho, Wo and K are read from it.
//   conv_strides             [0] = H stride, [1] = W stride.
//   conv_dilations           [0] = H dilation, [1] = W dilation.
//   in_left_pads             [0] = H left pad, [1] = W left pad.
//   in_right_pads            [0] = H right pad, [1] = W right pad.
//   Number<GemmK1Value>      tag argument carrying GemmK1 as a compile-time value.
//   GemmKBatch               number of parts the GemmKTotal dimension is split into.
//
// Returns: make_tuple(A descriptor, B descriptor, C descriptor).
//
// NOTE(review): nothing here checks that GemmKTotal is divisible by GemmKBatch * GemmK1;
// presumably the caller guarantees this - confirm.
// NOTE(review): the B and C descriptors are created with make_naive_tensor_descriptor_packed
// from lengths only; the strides of the output/weight descriptors passed in are not consulted.
template <typename... In,
          typename... Wei,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmK1Value,
          typename GemmKBatchType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad(
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<GemmK1Value>,
    GemmKBatchType GemmKBatch)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto GemmK1 = Number<GemmK1Value>{};

    // Input lengths, in N/Hi/Wi/C order.
    const auto N  = in_n_hi_wi_c_grid_desc.GetLength(I0);
    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
    const auto C  = in_n_hi_wi_c_grid_desc.GetLength(I3);

    // Output lengths (N is taken from the input descriptor above).
    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
    const auto K  = out_n_ho_wo_k_grid_desc.GetLength(I3);

    // Filter window lengths.
    const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
    const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);

    // Convolution parameters: index 0 is the H direction, index 1 the W direction.
    const auto stride_h   = conv_strides[I0];
    const auto stride_w   = conv_strides[I1];
    const auto dilation_h = conv_dilations[I0];
    const auto dilation_w = conv_dilations[I1];
    const auto pad_left_h  = in_left_pads[I0];
    const auto pad_left_w  = in_left_pads[I1];
    const auto pad_right_h = in_right_pads[I0];
    const auto pad_right_w = in_right_pads[I1];

    // GEMM problem sizes.
    const auto GemmM      = Y * X * C;
    const auto GemmN      = K;
    const auto GemmKTotal = N * Ho * Wo;
    const auto GemmK      = GemmKTotal / GemmKBatch;
    const auto GemmK0     = GemmK / GemmK1;

    // ---- A: input tensor ----
    // Step 1: pad Hi and Wi; N and C pass through unchanged.
    const auto in_padded_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, pad_left_h, pad_right_h),
                   make_pad_transform(Wi, pad_left_w, pad_right_w),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // Step 2: embed each padded spatial dimension into a (filter, output) pair using
    // (dilation, stride) as coefficients, giving [N, Y, Ho, X, Wo, C].
    const auto in_windowed_desc = transform_tensor_descriptor(
        in_padded_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(dilation_h, stride_h)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(dilation_w, stride_w)),
                   make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    // Step 3: merge (Y, X, C) into GemmM (upper dim 1) and (N, Ho, Wo) into GemmKTotal
    // (upper dim 0), giving [GemmKTotal, GemmM].
    const auto in_ktotal_m_desc = transform_tensor_descriptor(
        in_windowed_desc,
        make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                   make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // Step 4: split GemmKTotal into (GemmKBatch, GemmK0, GemmK1) and place GemmM at
    // dim 2, giving [GemmKBatch, GemmK0, GemmM, GemmK1].
    const auto a_kbatch_k0_m_k1_desc = transform_tensor_descriptor(
        in_ktotal_m_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // ---- B: output tensor ----
    // Described as a packed [GemmKTotal, GemmN] tensor, then split along GemmKTotal the
    // same way as A, giving [GemmKBatch, GemmK0, GemmN, GemmK1].
    const auto out_ktotal_n_desc =
        make_naive_tensor_descriptor_packed(make_tuple(GemmKTotal, GemmN));

    const auto b_kbatch_k0_n_k1_desc = transform_tensor_descriptor(
        out_ktotal_n_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

    // ---- C: weight tensor ----
    // Described as a packed [K, Y * X * C] tensor with its two dimensions swapped in the
    // upper view, giving [GemmM, GemmN].
    const auto c_m_n_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(GemmN, GemmM)),
        make_tuple(make_pass_through_transform(GemmN), make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    return make_tuple(a_kbatch_k0_m_k1_desc, b_kbatch_k0_n_k1_desc, c_m_n_desc);
}
}
// namespace ck
#endif
composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
View file @
0c8aa120
...
@@ -20,8 +20,7 @@ template <typename... In,
...
@@ -20,8 +20,7 @@ template <typename... In,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
,
typename
InRightPads
,
index_t
GemmK1Value
,
index_t
GemmK1Value
>
typename
GemmKBatchType
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
const
TensorDescriptor
<
In
...
>&
in_n_hi_wi_c_grid_desc
,
...
@@ -31,8 +30,7 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...
@@ -31,8 +30,7 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
Number
<
GemmK1Value
>
,
Number
<
GemmK1Value
>
)
GemmKBatchType
GemmKBatch
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -66,11 +64,10 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...
@@ -66,11 +64,10 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
Y
*
X
*
C
;
const
auto
GemmM
=
Y
*
X
*
C
;
const
auto
GemmN
=
K
;
const
auto
GemmN
=
K
;
const
auto
GemmKTotal
=
N
*
Ho
*
Wo
;
const
auto
GemmK
=
N
*
Ho
*
Wo
;
const
auto
GemmK
=
GemmKTotal
/
GemmKBatch
;
const
auto
GemmK0
=
GemmK
/
GemmK1
;
const
auto
GemmK0
=
GemmK
/
GemmK1
;
// A: input tensor
// A: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
...
@@ -91,30 +88,33 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...
@@ -91,30 +88,33 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmk
total
_gemmm_grid_desc
=
const
auto
in_gemmk_gemmm_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
)),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
in_
gemmkbatch_
gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmk0_gemmm_gemmk1_grid_desc
=
in_gemmk
total
_gemmm_grid_desc
,
transform_tensor_descriptor
(
in_gemmk_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmM
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// B: output tensor
// B: output tensor
const
auto
out_gemmktotal_gemmn_grid_desc
=
const
auto
out_gemmk_gemmn_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
)),
make_tuple
(
make_pass_through_transform
(
N
*
Ho
*
Wo
),
make_pass_through_transform
(
K
)),
const
auto
out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmKBatch
,
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK0
,
GemmK1
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// C: weight tensor
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmm_gemmn_grid_desc
=
transform_tensor_descriptor
(
...
@@ -123,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...
@@ -123,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
return
make_tuple
(
in_
gemmkbatch_
gemmk0_gemmm_gemmk1_grid_desc
,
return
make_tuple
(
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_
gemmkbatch_
gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
wei_gemmm_gemmn_grid_desc
);
}
}
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
View file @
0c8aa120
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_
atomic_
nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template
<
typename
TIn
,
template
<
typename
TIn
,
...
@@ -109,7 +109,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
...
@@ -109,7 +109,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#endif
#endif
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad
(
const
auto
descs
=
transform_backward_weight_convolution_into_gemm_v4r4r4_
atomic_
nhwc_kyxc_nhwk_pad
(
in_n_hi_wi_c_desc
,
in_n_hi_wi_c_desc
,
wei_k_y_x_c_desc
,
wei_k_y_x_c_desc
,
out_n_ho_wo_k_desc
,
out_n_ho_wo_k_desc
,
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
0c8aa120
...
@@ -4,7 +4,8 @@
...
@@ -4,7 +4,8 @@
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
TInWei
,
template
<
typename
TIn
,
typename
TWei
,
typename
TAcc
,
typename
TAcc
,
typename
TOut
,
typename
TOut
,
typename
InLengths
,
typename
InLengths
,
...
@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -22,8 +23,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
Tensor
<
TIn
Wei
>&
in_n_hi_wi_c
,
const
Tensor
<
TIn
>&
in_n_hi_wi_c
,
Tensor
<
T
In
Wei
>&
wei_k_y_x_c
,
Tensor
<
TWei
>&
wei_k_y_x_c
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
const
Tensor
<
TOut
>&
out_n_ho_wo_k
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
...
@@ -36,8 +37,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -36,8 +37,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TIn
Wei
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
in_n_hi_wi_c_device_buf
(
sizeof
(
TIn
)
*
in_n_hi_wi_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
T
In
Wei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_k_y_x_c_device_buf
(
sizeof
(
TWei
)
*
wei_k_y_x_c
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_ho_wo_k_device_buf
(
sizeof
(
TOut
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
());
DeviceMem
out_n_ho_wo_k_device_buf
(
sizeof
(
TOut
)
*
out_n_ho_wo_k
.
mDesc
.
GetElementSpace
());
in_n_hi_wi_c_device_buf
.
ToDevice
(
in_n_hi_wi_c
.
mData
.
data
());
in_n_hi_wi_c_device_buf
.
ToDevice
(
in_n_hi_wi_c
.
mData
.
data
());
...
@@ -194,9 +195,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -194,9 +195,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
{
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
BlockSize
,
TIn
Wei
,
TIn
,
TAcc
,
TAcc
,
T
Out
,
T
Wei
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
in_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_desc
),
...
@@ -234,9 +235,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
...
@@ -234,9 +235,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
TIn
Wei
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
>
(
static_cast
<
TIn
*>
(
in_n_hi_wi_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TOut
*>
(
out_n_ho_wo_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
In
Wei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TWei
*>
(
wei_k_y_x_c_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmm_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
0c8aa120
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#define USE_DYNAMIC_MODE 1
#define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
#define USE_CONV_WRW_V4R4R4_XDL_NHWC
0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC
1
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1
...
@@ -306,6 +306,7 @@ int main(int argc, char* argv[])
...
@@ -306,6 +306,7 @@ int main(int argc, char* argv[])
const
auto
tmp
=
f_make_for_device_nhwc
();
const
auto
tmp
=
f_make_for_device_nhwc
();
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
<
in_data_t
,
wei_data_t
,
acc_data_t
,
acc_data_t
,
out_data_t
>
(
out_data_t
>
(
tmp
[
I0
],
tmp
[
I0
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment