Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ccf94638
Commit
ccf94638
authored
Dec 19, 2024
by
Mateusz Ozga
Browse files
Pass 4d sequence and convert to 3d
parent
860433ea
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
114 additions
and
81 deletions
+114
-81
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
+6
-6
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
...uped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
+6
-6
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
..._weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
+16
-16
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+80
-47
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
...d_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
+6
-6
No files found.
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp
View file @
ccf94638
...
@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
...
@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
16
,
// NPerXdl
16
,
// NPerXdl
1
,
// MXdlPerWave
1
,
// MXdlPerWave
1
,
// NXdlPerWave
1
,
// NXdlPerWave
S
<
4
,
16
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
2
,
0
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
1
,
// ABlockTransferSrcScalarPerVector
4
,
// ABlockTransferDstScalarPerVector_K1
4
,
// ABlockTransferDstScalarPerVector_K1
false
,
// ABlockLdsAddExtraM
false
,
// ABlockLdsAddExtraM
S
<
4
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
1
,
// BBlockTransferSrcScalarPerVector
4
,
// BBlockTransferDstScalarPerVector_K1
4
,
// BBlockTransferDstScalarPerVector_K1
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp
View file @
ccf94638
...
@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
...
@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
32
,
// NPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
2
,
// MXdlPerWave
2
,
// NXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
16
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
2
,
0
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferSrcScalarPerVector
2
,
// ABlockTransferDstScalarPerVector_K1
2
,
// ABlockTransferDstScalarPerVector_K1
false
,
// ABlockLdsAddExtraM
false
,
// ABlockLdsAddExtraM
S
<
4
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferSrcScalarPerVector
2
,
// BBlockTransferDstScalarPerVector_K1
2
,
// BBlockTransferDstScalarPerVector_K1
...
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp
View file @
ccf94638
...
@@ -49,16 +49,16 @@ using DeviceConvBwdWeightInstance =
...
@@ -49,16 +49,16 @@ using DeviceConvBwdWeightInstance =
16
,
// NPerXdl
16
,
// NPerXdl
1
,
// MXdlPerWave
1
,
// MXdlPerWave
1
,
// NXdlPerWave
1
,
// NXdlPerWave
S
<
4
,
16
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
2
,
0
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
1
,
// ABlockTransferSrcScalarPerVector
4
,
// ABlockTransferDstScalarPerVector_K1
4
,
// ABlockTransferDstScalarPerVector_K1
false
,
// ABlockLdsAddExtraM
false
,
// ABlockLdsAddExtraM
S
<
4
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
1
,
// BBlockTransferSrcScalarPerVector
4
,
// BBlockTransferDstScalarPerVector_K1
4
,
// BBlockTransferDstScalarPerVector_K1
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
ccf94638
...
@@ -315,14 +315,43 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -315,14 +315,43 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
batch
);
batch
);
}
}
template
<
typename
SeqType
>
constexpr
static
auto
ShuffleSequenceAndTransformFrom4DTo3D
()
noexcept
(
noexcept
(
SeqType
{}.
Size
()
==
4
))
->
decltype
(
auto
)
{
// Remove first element and,
// Convert 4d->3d sequence.
constexpr
auto
_I0
=
SeqType
{}.
At
(
I1
);
constexpr
auto
_I1
=
SeqType
{}.
At
(
I2
);
constexpr
auto
_I2
=
SeqType
{}.
At
(
I0
);
constexpr
auto
_Seq
=
S
<
_I0
,
_I1
,
_I2
>
();
return
_Seq
;
}
template
<
typename
SeqType
>
constexpr
static
auto
TransformSequenceFrom4DTo3dAndReduceByOne
()
noexcept
(
noexcept
(
SeqType
{}.
Size
()
==
4
))
->
decltype
(
auto
)
{
// Skip first element and
// Convert 4d->3d and take away one from seq.
constexpr
index_t
one
=
1
;
constexpr
auto
_I0
=
SeqType
{}.
At
(
I1
)
-
one
;
constexpr
auto
_I1
=
SeqType
{}.
At
(
I2
)
-
one
;
constexpr
auto
_I2
=
SeqType
{}.
At
(
I3
)
-
one
;
constexpr
auto
_Seq
=
S
<
_I0
,
_I1
,
_I2
>
();
return
_Seq
;
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
GridwiseGemm_xdl_cshuffle_v3
<
tensor_layout
::
gemm
::
RowMajor
,
tensor_layout
::
gemm
::
RowMajor
,
tensor_layout
::
gemm
::
ColumnMajor
,
tensor_layout
::
gemm
::
ColumnMajor
,
tensor_layout
::
gemm
::
RowMajor
,
tensor_layout
::
gemm
::
RowMajor
,
ADataType
,
ADataType
,
...
@@ -344,17 +373,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
...
@@ -344,17 +373,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
NPerXdl
,
NPerXdl
,
MXdlPerWave
,
MXdlPerWave
,
NXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
decltype
(
ShuffleSequenceAndTransformFrom4DTo3D
<
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterLengths_K0_M_K1
>
()),
ABlockTransferSrcAccessOrder
,
decltype
(
TransformSequenceFrom4DTo3dAndReduceByOne
<
ABlockTransferThreadClusterArrangeOrder
>
()),
decltype
(
TransformSequenceFrom4DTo3dAndReduceByOne
<
ABlockTransferSrcAccessOrder
>
()),
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
false
,
ABlockLdsAddExtraM
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
decltype
(
ShuffleSequenceAndTransformFrom4DTo3D
<
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterLengths_K0_N_K1
>
()),
BBlockTransferSrcAccessOrder
,
decltype
(
TransformSequenceFrom4DTo3dAndReduceByOne
<
BBlockTransferThreadClusterArrangeOrder
>
()),
decltype
(
TransformSequenceFrom4DTo3dAndReduceByOne
<
BBlockTransferSrcAccessOrder
>
()),
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
View file @
ccf94638
...
@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
16
,
// NPerXdl
16
,
// NPerXdl
1
,
// MXdlPerWave
1
,
// MXdlPerWave
1
,
// NXdlPerWave
1
,
// NXdlPerWave
S
<
4
,
16
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
4
,
16
,
4
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
2
,
0
,
1
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// ABlockTransferSrcAccessOrder
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcVectorDim
1
,
// ABlockTransferSrcScalarPerVector
1
,
// ABlockTransferSrcScalarPerVector
4
,
// ABlockTransferDstScalarPerVector_K1
4
,
// ABlockTransferDstScalarPerVector_K1
false
,
// ABlockLdsAddExtraM
false
,
// ABlockLdsAddExtraM
S
<
4
,
16
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
4
,
16
,
4
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
2
,
0
,
1
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
0
,
3
,
1
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
S
<
0
,
2
,
1
,
3
>
,
// BBlockTransferSrcAccessOrder
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcVectorDim
1
,
// BBlockTransferSrcScalarPerVector
1
,
// BBlockTransferSrcScalarPerVector
4
,
// BBlockTransferDstScalarPerVector_K1
4
,
// BBlockTransferDstScalarPerVector_K1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment