Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
95710403
Commit
95710403
authored
Jun 02, 2021
by
Jing Zhang
Browse files
add kpack with incorrect results
parent
44078dba
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
220 additions
and
217 deletions
+220
-217
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+45
-23
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+147
-140
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+28
-54
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
95710403
...
@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -64,6 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadH
=
in_right_pads
[
I0
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
InRightPadW
=
in_right_pads
[
I1
];
const
auto
GemmM
=
K
;
const
auto
GemmN
=
N
*
Ho
*
Wo
;
const
auto
GemmK
=
C
*
Y
*
X
;
// weight tensor
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
K
,
C
*
Y
*
X
)),
...
@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -71,6 +75,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
wei_gemmk0_gemmm_gemmk1_global_desc
=
transform_dynamic_tensor_descriptor
(
wei_gemmk_gemmm_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK
/
GemmKPack
,
GemmKPack
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// input tensor
// input tensor
const
auto
in_n_c_hip_wip_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
in_n_c_hip_wip_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
in_n_c_hi_wi_global_desc
,
...
@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -97,6 +108,13 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_global_desc
=
transform_dynamic_tensor_descriptor
(
in_gemmk_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmK
/
GemmKPack
,
GemmKPack
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
// output tensor
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
N
,
K
,
Ho
*
Wo
)),
...
@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -104,11 +122,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
assert
(
GemmM
=
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
)
)
;
const
auto
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
assert
(
GemmN
=
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
)
)
;
const
auto
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
const
auto
GemmK
0
=
wei_gemmk
0
_gemmm_
gemmk1_
global_desc
.
GetLength
(
I0
);
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
0
%
GemmKPerBlock
==
0
);
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPack
>
{};
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
GemmMPerWave
,
GemmNPerWave
,
GemmKPack
>
{};
...
@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -129,22 +147,26 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
const
auto
out_gemm_block_cluster_desc
=
make_cluster_descriptor_v2
(
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
make_tuple
(
GemmM
/
Number
<
GemmMPerBlock
>
{},
GemmN
/
Number
<
GemmNPerBlock
>
{}));
// hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
constexpr
auto
wei_gemmk_gemmm_global_iterator_hacks
=
constexpr
auto
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
constexpr
auto
in_gemmk_gemmn_global_iterator_hacks
=
constexpr
auto
in_gemmk0_gemmn_gemmk1_global_iterator_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
>
{},
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
>
{}));
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
>
{}));
constexpr
auto
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
=
constexpr
auto
in_gemmk
0
_gemmn_
gemmk1_
global_move_slice_window_iterator_hacks
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
0
,
0
>
{};
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
// tensor hack for NKHW format
// tensor hack for NKHW format
...
@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -158,15 +180,15 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
2
,
0
,
0
>
{}));
return
make_tuple
(
wei_gemmk_gemmm_global_desc
,
return
make_tuple
(
wei_gemmk
0
_gemmm_
gemmk1_
global_desc
,
in_gemmk_gemmn_global_desc
,
in_gemmk
0
_gemmn_
gemmk1_
global_desc
,
out_m0_m1_m2_n_global_desc
,
out_m0_m1_m2_n_global_desc
,
out_gemm_block_cluster_desc
,
out_gemm_block_cluster_desc
,
wei_gemmk_gemmm_global_iterator_hacks
,
wei_gemmk
0
_gemmm_
gemmk1_
global_iterator_hacks
,
in_gemmk_gemmn_global_iterator_hacks
,
in_gemmk
0
_gemmn_
gemmk1_
global_iterator_hacks
,
out_m0_m1_m2_n_global_iterator_hacks
,
out_m0_m1_m2_n_global_iterator_hacks
,
wei_gemmk_gemmm_global_move_slice_window_iterator_hacks
,
wei_gemmk
0
_gemmm_
gemmk1_
global_move_slice_window_iterator_hacks
,
in_gemmk_gemmn_global_move_slice_window_iterator_hacks
);
in_gemmk
0
_gemmn_
gemmk1_
global_move_slice_window_iterator_hacks
);
}
}
}
// namespace ck
}
// namespace ck
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
95710403
This diff is collapsed.
Click to expand it.
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
95710403
...
@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -73,64 +73,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
const
auto
out_n_k_ho_wo_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
sequence_to_tuple_of_number
(
OutDesc
::
GetLengths
()));
sequence_to_tuple_of_number
(
OutDesc
::
GetLengths
()));
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_strides
=
sequence_to_tuple_of_number
(
ConvStrides
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
conv_dilations
=
sequence_to_tuple_of_number
(
ConvDilations
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_left_pads
=
sequence_to_tuple_of_number
(
InLeftPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
const
auto
in_right_pads
=
sequence_to_tuple_of_number
(
InRightPads
{});
#endif
#endif
#if 0
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlock = 64;
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 1;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#else
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerWave
=
32
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
32
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPack
=
4
;
constexpr
index_t
GemmKPack
=
4
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
2
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK
0
_GemmM
_GemmK1
=
Sequence
<
1
,
2
,
GemmKPack
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK
0
_GemmM
_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_
GemmM
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_
KPack
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK
0
_GemmN
_GemmK1
=
Sequence
<
1
,
2
,
GemmKPack
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK
0
_GemmN
_GemmK1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_
GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_
KPack
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#endif
const
auto
descs
=
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
TInWei
,
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
TInWei
,
...
@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -167,21 +141,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmKPack
,
GemmKPack
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadSliceLengths_GemmK
0
_GemmM
_GemmK1
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK
0
_GemmM
_GemmK1
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
,
2
>
,
0
,
2
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_
GemmM
,
GemmABlockTransferDstScalarPerVector_
KPack
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadSliceLengths_GemmK
0
_GemmN
_GemmK1
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK
0
_GemmN
_GemmK1
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_
GemmN
,
GemmBBlockTransferDstScalarPerVector_
KPack
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
Sequence
<
2
,
3
,
0
,
1
>
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment