gaoqiong / composable_kernel · Commits · 78f6584a

Commit 78f6584a, authored Jun 04, 2021 by Jing Zhang

    debugging

Parent: 17daf766

Showing 2 changed files with 37 additions and 12 deletions:

    composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (+34, -6)
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp (+3, -6)
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @ 78f6584a
...
...
@@ -64,9 +64,10 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto InRightPadH = in_right_pads[I0];
     const auto InRightPadW = in_right_pads[I1];

-    const auto GemmM = K;
-    const auto GemmN = N * Ho * Wo;
-    const auto GemmK = C * Y * X;
+    const auto GemmM  = K;
+    const auto GemmN  = N * Ho * Wo;
+    const auto GemmK  = C * Y * X;
+    const auto GemmK0 = GemmK / GemmKPack;

     // weight tensor
     const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
...
...
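Aside: a minimal standalone sketch of what the GEMM sizes introduced above represent for this NCHW/KCYX/NKHW implicit-GEMM path. All concrete numbers (N, C, Hi, Wi, K, Y, X, Ho, Wo, GemmKPack) are made up for illustration and are not taken from this commit.

// Standalone sketch (not part of the diff): how the implicit-GEMM sizes above
// are derived from an NCHW convolution problem. Numbers are illustrative only.
#include <cassert>
#include <cstdio>

int main()
{
    // hypothetical convolution problem (NCHW input, KCYX weight, NKHW output)
    const int N = 128, C = 192, Hi = 71, Wi = 71;
    const int K = 256, Y = 3, X = 3;
    const int Ho = 71, Wo = 71;     // assuming stride 1 and "same" padding
    const int GemmKPack = 8;        // hypothetical KPack vector length

    const int GemmM  = K;           // one GEMM row per output channel
    const int GemmN  = N * Ho * Wo; // one GEMM column per output pixel
    const int GemmK  = C * Y * X;   // reduction over input channels and filter window
    const int GemmK0 = GemmK / GemmKPack;

    assert(GemmK % GemmKPack == 0); // required for the K -> (K0, KPack) split below
    std::printf("GemmM=%d GemmN=%d GemmK=%d GemmK0=%d\n", GemmM, GemmN, GemmK, GemmK0);
    (void)Hi; (void)Wi;
}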
@@ -77,7 +78,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
         wei_gemmk_gemmm_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
                    make_pass_through_transform(GemmM)),
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...
...
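For reference, a rough sketch of the index arithmetic that an unmerge transform like the one above expresses: the flat GemmK coordinate is split into a (GemmK0, GemmKPack) pair, with the GemmKPack component fastest-varying. This is plain C++ for illustration, not the library's tensor-descriptor machinery.

// Illustration only: the flat gemmk index is split into (gemmk0, gemmk1)
// such that gemmk == gemmk0 * kpack + gemmk1.
#include <cassert>
#include <utility>

std::pair<int, int> unmerge_k(int gemmk, int kpack)
{
    return {gemmk / kpack, gemmk % kpack};
}

int main()
{
    const int GemmKPack = 8; // hypothetical value
    for (int gemmk = 0; gemmk < 64; ++gemmk)
    {
        auto [k0, k1] = unmerge_k(gemmk, GemmKPack);
        assert(k0 * GemmKPack + k1 == gemmk);
    }
}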
@@ -110,7 +111,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
         in_gemmk_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmK / GemmKPack, GemmKPack)),
+        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
                    make_pass_through_transform(GemmN)),
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
...
...
@@ -124,7 +125,8 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
     assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));

-    const auto GemmK0 = wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0);
+    assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
+    assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));

     assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
            GemmK0 % GemmKPerBlock == 0);
...
...
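A small sketch of why the divisibility asserts above matter: GemmM and GemmN must tile evenly into the per-block sizes that define the launch grid, and GemmK0 must tile evenly into the K-loop step. The tile sizes below (GemmMPerBlock, GemmNPerBlock, GemmKPerBlock) are hypothetical, not taken from this commit.

// Sketch only: the grid is (GemmM/GemmMPerBlock) x (GemmN/GemmNPerBlock)
// workgroups and the main K loop runs GemmK0/GemmKPerBlock times; none of
// this works if the divisions leave remainders.
#include <cassert>
#include <cstdio>

int main()
{
    const int GemmM = 256, GemmN = 128 * 71 * 71, GemmK0 = 216;          // from the earlier sketch
    const int GemmMPerBlock = 128, GemmNPerBlock = 128, GemmKPerBlock = 8; // hypothetical tile sizes

    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
           GemmK0 % GemmKPerBlock == 0);

    const int grid_size = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
    const int k_loops   = GemmK0 / GemmKPerBlock;
    std::printf("grid_size=%d, main K loop iterations=%d\n", grid_size, k_loops);
}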
@@ -156,6 +158,7 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
         Sequence<0, 0, 0, 0, 0>{};

+#if 0
     // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
     constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
...
...
@@ -179,6 +182,31 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
                               Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 2, 0, 0>{}));
+#else
+    // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
+
+    constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
+
+    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
+    // tensor hack for NKHW format
+    constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}));
+#endif

     return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
                       in_gemmk0_gemmn_gemmk1_global_desc,
...
...
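The iterator "hacks" above are tuples of compile-time integer Sequences that, per the comments in the file, steer index calculation while iterating over the global tensors; under the new #else branch they are all zero. As a rough analogue of such a compile-time integer list (using std::integer_sequence, not composable_kernel's own Sequence type), the values are carried in the type and can be inspected at compile time at zero runtime cost:

// Rough analogue only: a compile-time list like Sequence<0, 0, 2, 0, 0>{}
// bakes its values into the type, so code can branch on them at compile time.
#include <utility>

template <int... Is>
using IntSeq = std::integer_sequence<int, Is...>;

template <int... Is>
constexpr int sum(IntSeq<Is...>) // e.g. "is any hack flag set?"
{
    return (0 + ... + Is);
}

static_assert(sum(IntSeq<0, 0, 2, 0, 0>{}) == 2, "non-trivial hack pattern");
static_assert(sum(IntSeq<0, 0, 0, 0, 0>{}) == 0, "all-zero pattern as in the #else branch");

int main() {}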
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp @ 78f6584a
...
...
@@ -141,9 +141,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
 {
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
-                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
-                                                 Number<KPack>{});
+        constexpr auto max_lds_align = KPack;

         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
...
...
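The removed lines computed the LDS alignment as the lcm of the two per-vector copy widths and KPack; the added line pins it to KPack alone. A plain C++ stand-in for that computation, with hypothetical config values (not the library's math::lcm or its real tile sizes):

// Sketch under assumed values: how the old lcm-based alignment and the new
// KPack-only alignment can differ when rounding an LDS leading dimension.
#include <cstdio>
#include <numeric> // std::lcm

int main()
{
    const int a_vec = 4, b_vec = 4, kpack = 2; // hypothetical config values
    const int old_align = std::lcm(std::lcm(a_vec, b_vec), kpack); // lcm(4, 4, 2) = 4
    const int new_align = kpack;                                   // 2

    // round a leading dimension up to the alignment, as an LDS layout would
    auto pad_up = [](int x, int align) { return (x + align - 1) / align * align; };

    const int lds_row = 130; // hypothetical unpadded leading dimension
    std::printf("old: align=%d padded=%d | new: align=%d padded=%d\n",
                old_align, pad_up(lds_row, old_align),
                new_align, pad_up(lds_row, new_align));
}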
@@ -192,6 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
         const auto M  = a_k0_m_k1_global_desc.GetLength(I1);
         const auto N  = b_k0_n_k1_global_desc.GetLength(I1);
+        const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);

         // divide block work by [M, N]
         const auto block_work_idx =
...
...
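The hunk above reads K0, M, N and (newly) K1 from the global descriptors before the block work is divided over [M, N]. As a rough sketch only, and not necessarily the library's actual block-to-tile mapping, a flat workgroup id can be split into the m/n tile indices that expressions like block_work_idx[I1] * NPerBlock then scale into element offsets:

// Illustration with made-up sizes: divide a flat workgroup id over the
// [M, N] tile grid and recover the per-block tile origin.
#include <cstdio>

int main()
{
    const int M = 256, N = 645248;              // GEMM sizes from the earlier sketch
    const int MPerBlock = 128, NPerBlock = 128; // hypothetical tile sizes
    const int m_blocks = M / MPerBlock, n_blocks = N / NPerBlock;

    const int block_id    = 4321;               // stand-in for the HW workgroup id
    const int m_block_idx = block_id / n_blocks; // analogous to block_work_idx[I0]
    const int n_block_idx = block_id % n_blocks; // analogous to block_work_idx[I1]

    std::printf("tile origin: m=%d n=%d (of %d x %d tiles)\n",
                m_block_idx * MPerBlock, n_block_idx * NPerBlock, m_blocks, n_blocks);
}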
@@ -205,9 +204,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_KPack>{},
-                                                 Number<BBlockTransferDstScalarPerVector_KPack>{},
-                                                 Number<KPack>{});
+        constexpr auto max_lds_align = KPack;

         // A matrix in LDS memory, dst of blockwise copy
         //   be careful of LDS alignment
...
...