Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
840cba8e
Commit
840cba8e
authored
Sep 03, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/moe
parents
bf8e6de7
73b67f29
Changes
24
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
802 additions
and
79 deletions
+802
-79
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
...evice_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
+5
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+373
-39
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+5
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
...grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+30
-14
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+33
-1
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+52
-2
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+8
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
+40
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
...tion_instance/gpu/grouped_convolution_backward_weight.hpp
+30
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
..._instance/gpu/grouped_convolution_backward_weight_xdl.inc
+46
-0
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
...ary/utility/convolution_host_tensor_descriptor_helper.hpp
+41
-1
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+2
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
...t_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
...t_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
+2
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
...wo_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
+41
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
840cba8e
...
...
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return
false
;
if
constexpr
(
!
((
NDimSpatial
==
1
&&
(
is_NWG
K
_GKXC_NWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNW
K
_GKXC_GNW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
||
(
is_NWG
C
_GKXC_NWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNW
C
_GKXC_GNW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
||
(
NDimSpatial
==
2
&&
(
is_NHWG
K
_GKYXC_NHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
K
_GKYXC_GNHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
||
(
is_NHWG
C
_GKYXC_NHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
C
_GKYXC_GNHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
||
(
NDimSpatial
==
3
&&
(
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))))
(
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
View file @
840cba8e
...
...
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
if
constexpr
(
NDimSpatial
==
1
)
{
if
constexpr
(
!
is_GNW
K
_GKXC_GNW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
())
if
constexpr
(
!
is_GNW
C
_GKXC_GNW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
return
false
;
}
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
if
constexpr
(
!
(
is_NHWG
K
_GKYXC_NHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
K
_GKYXC_GNHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
if
constexpr
(
!
(
is_NHWG
C
_GKYXC_NHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
C
_GKYXC_GNHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
{
return
false
;
}
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
if
constexpr
(
!
(
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
if
constexpr
(
!
(
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
840cba8e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @
840cba8e
...
...
@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return
false
;
}
if
constexpr
(
!
(
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
if
constexpr
(
!
(
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
840cba8e
...
...
@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
}
if
constexpr
(
NDimSpatial
==
1
)
{
if
constexpr
(
!
is_GNW
K
_GKXC_GNW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
())
if
constexpr
(
!
is_GNW
C
_GKXC_GNW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
return
false
;
}
}
else
if
constexpr
(
NDimSpatial
==
2
)
{
if
constexpr
(
!
(
is_NHWG
K
_GKYXC_NHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
K
_GKYXC_GNHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
if
constexpr
(
!
(
is_NHWG
C
_GKYXC_NHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
C
_GKYXC_GNHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
{
return
false
;
}
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
if
constexpr
(
!
(
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
if
constexpr
(
!
(
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
840cba8e
...
...
@@ -925,7 +925,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return
false
;
}
}
if
constexpr
(
!
is_NSpatialG
K
_GKSpatial_NSpatialG
C
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
is_NSpatialG
C
_GKSpatial_NSpatialG
K
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
...
...
@@ -941,7 +941,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return
false
;
}
if
constexpr
(
!
is_NSpatialG
K
_GKSpatial_NSpatialG
C
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
is_NSpatialG
C
_GKSpatial_NSpatialG
K
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
...
...
@@ -960,7 +960,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
// If not possible, check access per G
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
C
==
1
&&
is_NSpatialG
K
_GKSpatial_NSpatialG
C
<
ALayout
,
BLayout
,
ELayout
>
()
&&
is_NSpatialG
C
_GKSpatial_NSpatialG
K
<
ALayout
,
BLayout
,
ELayout
>
()
&&
G
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
View file @
840cba8e
...
...
@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return
false
;
}
}
if
constexpr
(
!
is_NSpatialG
K
_GKSpatial_NSpatialG
C
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
is_NSpatialG
C
_GKSpatial_NSpatialG
K
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
840cba8e
...
...
@@ -12,7 +12,7 @@ namespace device {
// 1d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NWG
K
_GKXC_NWG
C
()
constexpr
bool
is_NWG
C
_GKXC_NWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
...
...
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNW
K
_GKXC_GNW
C
()
constexpr
bool
is_GNW
C
_GKXC_GNW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
...
...
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
}
// 2d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NHWG
K
_GKYXC_NHWG
C
()
constexpr
bool
is_NHWG
C
_GKYXC_NHWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
...
...
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNHW
K
_GKYXC_GNHW
C
()
constexpr
bool
is_GNHW
C
_GKYXC_GNHW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNHWK
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCHW_GKYXC_NGKHW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCHW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKHW
>
;
}
// 3d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NDHWG
K
_GKZYXC_NDHWG
C
()
constexpr
bool
is_NDHWG
C
_GKZYXC_NDHWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
...
...
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNDHW
K
_GKZYXC_GNDHW
C
()
constexpr
bool
is_GNDHW
C
_GKZYXC_GNDHW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
...
...
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NSpatialGK_GKSpatial_NSpatialGC
()
constexpr
bool
is_NGCDHW_GKZYXC_NGKDHW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCDHW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKDHW
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NSpatialGC_GKSpatial_NSpatialGK
()
{
return
is_NWG
K
_GKXC_NWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NHWG
K
_GKYXC_NHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
();
return
is_NWG
C
_GKXC_NWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NHWG
C
_GKYXC_NHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNSpatial
K
_GKSpatial_GNSpatial
C
()
constexpr
bool
is_GNSpatial
C
_GKSpatial_GNSpatial
K
()
{
return
is_GNW
K
_GKXC_GNW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
K
_GKYXC_GNHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
();
return
is_GNW
C
_GKXC_GNW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
C
_GKYXC_GNHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
...
...
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"NDHWGC"
;
};
// input tensor
// packed NGCW/NGCHW/NGCDHW
struct
NGCW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCW"
;
};
struct
NGCHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCHW"
;
};
struct
NGCDHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCDHW"
;
};
// input tensor
// strided layout
struct
G_NW_C
:
public
BaseTensorLayout
...
...
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"NDHWGK"
;
};
struct
NGKW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKW"
;
};
struct
NGKHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKHW"
;
};
struct
NGKDHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKDHW"
;
};
// output tensor
// strided layout
struct
G_NW_K
:
public
BaseTensorLayout
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
View file @
840cba8e
...
...
@@ -41,6 +41,55 @@ __global__ void
elementwise_op
);
}
template
<
typename
GridwiseElementwiseFunctor
,
typename
InAGridDescTuple
,
typename
InBGridDescTuple
,
typename
OutAGridDescTuple
,
typename
OutBGridDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
Block2TileMapA
,
typename
Block2TileMapB
,
typename
ElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_elementwise_dual
(
const
InBGridDescTuple
in_grid_desc_tuple_a
,
const
InBGridDescTuple
in_grid_desc_tuple_b
,
const
OutAGridDescTuple
out_grid_desc_tuple_a
,
const
OutBGridDescTuple
out_grid_desc_tuple_b
,
const
InDataTypePointerTuple
p_in_global_tuple_a
,
const
InDataTypePointerTuple
p_in_global_tuple_b
,
const
OutDataTypePointerTuple
p_out_global_tuple_a
,
const
OutDataTypePointerTuple
p_out_global_tuple_b
,
const
Block2TileMapA
block_2_tile_map_a
,
const
Block2TileMapB
block_2_tile_map_b
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
a_grid_size
)
{
if
(
get_block_1d_id
()
<
a_grid_size
)
{
GridwiseElementwiseFunctor
::
Run
(
in_grid_desc_tuple_a
,
out_grid_desc_tuple_a
,
p_in_global_tuple_a
,
p_out_global_tuple_a
,
block_2_tile_map_a
,
elementwise_op
,
get_block_1d_id
());
}
else
{
GridwiseElementwiseFunctor
::
Run
(
in_grid_desc_tuple_b
,
out_grid_desc_tuple_b
,
p_in_global_tuple_b
,
p_out_global_tuple_b
,
block_2_tile_map_b
,
elementwise_op
,
get_block_1d_id
()
-
a_grid_size
);
}
}
template
<
typename
GridwiseElementwiseFunctor
,
typename
InGridDescTuple
,
typename
OutGridDescTuple
,
...
...
@@ -133,7 +182,8 @@ struct GridwiseElementwise
const
InDataTypePointerTuple
&
p_in_global_tuple
,
const
OutDataTypePointerTuple
&
p_out_global_tuple
,
const
Block2TileMap
&
block_2_tile_map
,
const
ElementwiseOperation
&
elementwise_op
)
const
ElementwiseOperation
&
elementwise_op
,
const
index_t
block_id
=
get_block_1d_id
())
{
constexpr
auto
src_datas
=
generate_tuple
(
...
...
@@ -169,7 +219,7 @@ struct GridwiseElementwise
Number
<
NumOutput
>
{});
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_
block_
1d_id
()
));
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
block_
id
));
const
index_t
m0_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
M0PerBlock
);
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
840cba8e
...
...
@@ -74,6 +74,10 @@ using GNWK = ck::tensor_layout::convolution::GNWK;
using
GNHWK
=
ck
::
tensor_layout
::
convolution
::
GNHWK
;
using
GNDHWK
=
ck
::
tensor_layout
::
convolution
::
GNDHWK
;
using
NGKW
=
ck
::
tensor_layout
::
convolution
::
NGKW
;
using
NGKHW
=
ck
::
tensor_layout
::
convolution
::
NGKHW
;
using
NGKDHW
=
ck
::
tensor_layout
::
convolution
::
NGKDHW
;
//
using
NWGC
=
ck
::
tensor_layout
::
convolution
::
NWGC
;
using
NHWGC
=
ck
::
tensor_layout
::
convolution
::
NHWGC
;
...
...
@@ -87,6 +91,10 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
using
NHWGK
=
ck
::
tensor_layout
::
convolution
::
NHWGK
;
using
NDHWGK
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
NGCW
=
ck
::
tensor_layout
::
convolution
::
NGCW
;
using
NGCHW
=
ck
::
tensor_layout
::
convolution
::
NGCHW
;
using
NGCDHW
=
ck
::
tensor_layout
::
convolution
::
NGCDHW
;
//
using
G_K
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
GK_Tuple
=
ck
::
Tuple
<
G_K
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp
View file @
840cba8e
...
...
@@ -56,6 +56,46 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std
// clang-format on
>
;
// NGCHW requires transpose, we use vector loads and stores params for them
template
<
ck
::
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
ConvolutionBackwardWeightSpecialization
ConvSpec
,
BlockGemmPipelineScheduler
Scheduler
,
BlockGemmPipelineVersion
PipelineVersion
>
using
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
=
std
::
tuple
<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
16
,
16
,
32
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
1
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
1
,
F16
,
F16
,
1
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
2
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
4
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
8
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
2
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
4
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
8
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
1
,
2
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
1
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
1
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
1
,
4
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
1
,
8
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
32
,
32
,
8
,
32
,
32
,
1
,
1
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
2
,
2
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
2
,
F16
,
F16
,
2
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
64
,
32
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
32
,
128
,
32
,
8
,
32
,
32
,
1
,
4
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
4
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
8
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
64
,
32
,
32
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
S
<
4
,
8
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
4
,
4
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
8
>
,
1
,
Scheduler
,
PipelineVersion
,
4
,
F16
,
F16
,
4
,
1
>
,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<
NDimSpatial
,
ALayout
,
BLayout
,
ELayout
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
ConvSpec
,
64
,
128
,
32
,
32
,
8
,
32
,
32
,
4
,
1
,
S
<
4
,
16
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
S
<
4
,
4
,
1
>
,
S
<
2
,
0
,
1
>
,
S
<
1
,
0
,
2
>
,
1
,
8
,
8
,
false
,
1
,
1
,
S
<
1
,
8
,
1
,
4
>
,
1
,
Scheduler
,
PipelineVersion
,
8
,
F16
,
F16
,
8
,
1
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
View file @
840cba8e
...
...
@@ -367,6 +367,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances
(
op_ptrs
);
}
#endif
}
if
constexpr
(
is_same_v
<
InLayout
,
NGCHW
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
NGKHW
>
)
{
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
ComputeTypeB
,
half_t
>
)
{
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
op_ptrs
);
}
#endif
}
}
...
...
@@ -447,6 +462,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances
(
op_ptrs
);
}
#endif
}
if
constexpr
(
is_same_v
<
InLayout
,
NGCDHW
>
&&
is_same_v
<
WeiLayout
,
GKZYXC
>
&&
is_same_v
<
OutLayout
,
NGKDHW
>
)
{
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
&&
is_same_v
<
ComputeTypeA
,
half_t
>
&&
is_same_v
<
ComputeTypeB
,
half_t
>
)
{
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
op_ptrs
);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
op_ptrs
);
}
#endif
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc
View file @
840cba8e
...
...
@@ -137,6 +137,29 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances
(
...
...
@@ -240,6 +263,29 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_FP32
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
...
...
library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp
View file @
840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -46,6 +46,21 @@ std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGCW
>
||
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGKW
>
)
{
return
{
1
,
0
,
2
,
3
};
}
else
if
constexpr
(
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGCHW
>
||
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGKHW
>
)
{
return
{
1
,
0
,
2
,
3
,
4
};
}
else
if
constexpr
(
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGCDHW
>
||
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
NGKDHW
>
)
{
return
{
1
,
0
,
2
,
3
,
4
,
5
};
}
else
if
constexpr
(
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
GNCHW
>
||
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
GKCYX
>
||
ck
::
is_same_v
<
OldLayout
,
ck
::
tensor_layout
::
convolution
::
GNKHW
>
)
...
...
@@ -132,6 +147,18 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NGCW
>
||
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NGCHW
>
||
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NGCDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
GNCW
>
||
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
GNCHW
>
||
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
GNCDHW
>
)
...
...
@@ -314,6 +341,19 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NGKW
>
||
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NGKHW
>
||
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NGKDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
GNWK
>
||
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
GNHWK
>
||
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
>
)
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
840cba8e
...
...
@@ -8,6 +8,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
0 → 100644
View file @
840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp
0 → 100644
View file @
840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
<
2
,
NGCHW
,
GKYXC
,
NGKHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
View file @
840cba8e
...
...
@@ -8,6 +8,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp
0 → 100644
View file @
840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment