Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
69ccb257
Commit
69ccb257
authored
Aug 15, 2023
by
Bartlomiej Kocot
Browse files
Add instances for small K and small C
parent
11e39518
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
78 additions
and
55 deletions
+78
-55
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+7
-6
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+44
-46
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
...nv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
+18
-3
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp
.../grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp
+9
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
69ccb257
...
...
@@ -266,12 +266,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
AK
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
BK
=
b_grid_desc_n_k
.
GetLength
(
I1
);
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
AK
==
BK
))
{
return
false
;
}
...
...
@@ -289,13 +290,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check tile size
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
A
K
%
KPerBlock
==
0
))
{
return
false
;
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
const
auto
num_k_loop
=
A
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
69ccb257
...
...
@@ -364,20 +364,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmmraw_grid_desc
.
GetLength
(
I1
))),
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
...
...
@@ -457,20 +457,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmmraw_grid_desc
.
GetLength
(
I1
))),
const
auto
out_gemmk_gemmm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmk_gemmmraw_grid_desc
,
make_tuple
(
AK1
,
GemmMPerBlock
),
Sequence
<
true
,
DoPadGemmM
>
{});
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
out_gemmk_gemmm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmak0_gemmm_gemmak1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
out_gemmak0_gemmmraw_gemmak1_grid_desc
,
make_tuple
(
AK0
,
GemmMPerBlock
,
AK1
),
Sequence
<
false
,
DoPadGemmM
,
false
>
{});
return
out_gemmak0_gemmm_gemmak1_grid_desc
;
}
else
...
...
@@ -614,21 +614,20 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
2
,
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmnraw_grid_desc
,
const
auto
wei_gemmk_gemmn_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmn_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
wei_gemmk_gemmn_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
...
...
@@ -691,22 +690,21 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
C
)),
const
auto
wei_gemmk_gemm_padded_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmk_gemmnraw_grid_desc
,
make_tuple
(
BK1
,
GemmNPerBlock
),
Sequence
<
true
,
DoPadGemmN
>
{});
const
auto
wei_gemmbk0_gemm_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_gemm_padded_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
wei_gemmk_gemm_padded_grid_desc
.
GetLength
(
I1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
,
make_tuple
(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
.
GetLength
(
I0
),
GemmNPerBlock
,
BK1
),
Sequence
<
false
,
DoPadGemmN
,
false
>
{});
return
wei_gemmbk0_gemmn_gemmbk1_grid_desc
;
return
wei_gemmbk0_gemm_gemmbk1_grid_desc
;
}
else
{
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
View file @
69ccb257
...
...
@@ -47,6 +47,12 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
1
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
...
...
@@ -61,7 +67,6 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F16
,
F16
,
F32
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// TODO: After enable, add instance for small conv.K and conv.C
// clang-format on
>
;
...
...
@@ -78,6 +83,12 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple<
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
1
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
...
...
@@ -92,7 +103,6 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple<
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
BF16
,
BF16
,
F32
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// TODO: After enable, add instance for small conv.K and conv.C
// clang-format on
>
;
...
...
@@ -110,6 +120,12 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
1
>
,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
1
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
4
>
,
...
...
@@ -124,7 +140,6 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
>
,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
F32
,
F32
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
true
,
true
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
// TODO: After enable, add instance for small conv.K and conv.C
// clang-format on
>
;
...
...
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp
View file @
69ccb257
...
...
@@ -87,6 +87,9 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
8
,
8
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
template
Run
<
2
>();
}
...
...
@@ -99,5 +102,11 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
template
Run
<
3
>();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment