gaoqiong / composable_kernel_ROCM / Commits

Commit 2edac9f1, authored May 30, 2024 by Bartlomiej Kocot

Integrate universal gemm with conv bwd data

parent 34f3dfdd
Showing 20 changed files with 1808 additions and 236 deletions:

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  +613 -89
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp  +2 -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp  +84 -21
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_comp_instance.hpp  +105 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp  +210 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp  +69 -15
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_comp.inc  +199 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_mem_inter.inc  +199 -0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_mem_intra.inc  +13 -13
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt  +20 -6
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp  +15 -15
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_inter_instance.cpp  +51 -0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_intra_instance.cpp  +51 -0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_inter_instance.cpp  +51 -0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_intra_instance.cpp  +51 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
(diff collapsed in the original page view)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -674,7 +674,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                     clear_workspace();
                 };

-                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
+                ave_time += ck::utility::launch_and_time_kernel_with_preprocess<false>(
                     stream_config,
                     run_flush_cache,
                     kernel,
@@ -690,7 +690,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             }
             else
             {
-                ave_time = launch_and_time_kernel_with_preprocess(
+                ave_time += launch_and_time_kernel_with_preprocess(
                     stream_config,
                     clear_workspace,
                     kernel,
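The only change in both hunks is `=` becoming `+=`: when the two-stage backward-weight path times more than one kernel launch in a row, each measured time now accumulates into ave_time instead of overwriting the previous value. A standalone illustration of the difference (plain C++, not CK code; the launch count and timing stub are invented):

#include <cstdio>

// Stand-in for launch_and_time_kernel_with_preprocess(): returns one launch's time in ms.
static float time_one_launch() { return 0.1f; }

int main()
{
    float ave_time = 0.f;
    for(int i = 0; i < 3; ++i)          // e.g. several launches belonging to one operation
    {
        ave_time += time_one_launch();  // with '=', only the last launch would be reported
    }
    std::printf("total measured time: %.1f ms\n", ave_time);
    return 0;
}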
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -1123,7 +1123,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
     }

     template <typename CGridDesc>
-    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
         const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
     {
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(

@@ -1141,26 +1141,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
     using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
     // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

-    template <bool HasMainKBlockLoop,
+    template <typename AGridDesc_AK0_M_K1,
+              typename BGridDesc_BK0_N_K1,
+              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+              bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,
               TailNumber TailNum = TailNumber::Odd>
     __device__ static void Run(const ADataType* p_a_grid,
                                const BDataType* p_b_grid,
                                CDataType* p_c_grid,
                                void* p_shared,
-                               const Problem& problem)
+                               const Problem& problem,
+                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
+                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
+                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
     {
-        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
-        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
-            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
-        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
-            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
-        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
-            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                c_grid_desc_m_n, problem.MBlock, problem.NBlock);
-
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -1508,12 +1504,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,
               TailNumber TailNum = TailNumber::Odd>
-    __device__ static void Run_2Lds(const ADataType* p_a_grid,
-                                    const BDataType* p_b_grid,
-                                    CDataType* p_c_grid,
-                                    void* p_shared_0,
-                                    void* p_shared_1,
-                                    const Problem& problem)
+    __device__ static void Run(const ADataType* p_a_grid,
+                               const BDataType* p_b_grid,
+                               CDataType* p_c_grid,
+                               void* p_shared,
+                               const Problem& problem)
     {
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
             problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

@@ -1521,11 +1516,42 @@ struct GridwiseGemm_xdl_cshuffle_v3
             problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
             problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                 c_grid_desc_m_n, problem.MBlock, problem.NBlock);

+        Run<decltype(a_grid_desc_ak0_m_ak1),
+            decltype(b_grid_desc_bk0_n_bk1),
+            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+            HasMainKBlockLoop,
+            CGlobalMemoryDataOperation,
+            TailNum>(p_a_grid,
+                     p_b_grid,
+                     p_c_grid,
+                     p_shared,
+                     problem,
+                     a_grid_desc_ak0_m_ak1,
+                     b_grid_desc_bk0_n_bk1,
+                     c_grid_desc_mblock_mperblock_nblock_nperblock);
+    }
+
+    template <typename AGridDesc_AK0_M_K1,
+              typename BGridDesc_BK0_N_K1,
+              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+              bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              TailNumber TailNum = TailNumber::Odd>
+    __device__ static void Run_2Lds(const ADataType* p_a_grid,
+                                    const BDataType* p_b_grid,
+                                    CDataType* p_c_grid,
+                                    void* p_shared_0,
+                                    void* p_shared_1,
+                                    const Problem& problem,
+                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
+                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
+                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                        c_grid_desc_mblock_mperblock_nblock_nperblock)
+    {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -1879,6 +1905,43 @@ struct GridwiseGemm_xdl_cshuffle_v3
         });
     }
 }
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              TailNumber TailNum = TailNumber::Odd>
+    __device__ static void Run_2Lds(const ADataType* p_a_grid,
+                                    const BDataType* p_b_grid,
+                                    CDataType* p_c_grid,
+                                    void* p_shared_0,
+                                    void* p_shared_1,
+                                    const Problem& problem)
+    {
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
+            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
+            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
+        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
+            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                c_grid_desc_m_n, problem.MBlock, problem.NBlock);
+
+        Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
+                 decltype(b_grid_desc_bk0_n_bk1),
+                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+                 HasMainKBlockLoop,
+                 CGlobalMemoryDataOperation,
+                 TailNum>(p_a_grid,
+                          p_b_grid,
+                          p_c_grid,
+                          p_shared_0,
+                          p_shared_1,
+                          problem,
+                          a_grid_desc_ak0_m_ak1,
+                          b_grid_desc_bk0_n_bk1,
+                          c_grid_desc_mblock_mperblock_nblock_nperblock);
+    }
 };

 } // namespace ck
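Taken together, these hunks let a caller hand pre-built A/B/C grid descriptors to the universal GEMM instead of having Run/Run_2Lds rebuild them from the Problem on every launch; the Problem-only overloads remain as thin wrappers, and MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock becoming __host__ __device__ means the C descriptor can also be prepared on the host. A simplified, self-contained C++ analogue of that calling pattern (the Desc*/Make* helpers below are invented for illustration and are not CK types):

#include <cstdio>

struct Problem { int M, N, K; };

// Invented stand-ins for CK grid descriptors (illustration only).
struct DescA { int element_space; };
struct DescB { int element_space; };
struct DescC { int element_space; };

static DescA MakeAGridDescriptor(const Problem& p) { return {p.M * p.K}; }
static DescB MakeBGridDescriptor(const Problem& p) { return {p.K * p.N}; }
static DescC MakeCGridDescriptor(const Problem& p) { return {p.M * p.N}; }

// "Universal" entry point: descriptors are supplied by the caller, so e.g. a grouped
// conv bwd-data device op can pass descriptors built from its own image-to-GEMM transforms.
template <typename ADesc, typename BDesc, typename CDesc>
void Run(const Problem& p, const ADesc& a, const BDesc& b, const CDesc& c)
{
    std::printf("GEMM %dx%dx%d, element spaces %d/%d/%d\n",
                p.M, p.N, p.K, a.element_space, b.element_space, c.element_space);
}

// Thin wrapper keeping the old behaviour: derive the descriptors from the Problem itself.
void Run(const Problem& p)
{
    Run(p, MakeAGridDescriptor(p), MakeBGridDescriptor(p), MakeCGridDescriptor(p));
}

int main()
{
    const Problem p{128, 256, 64};
    Run(p);                                                                          // plain GEMM path
    Run(p, MakeAGridDescriptor(p), MakeBGridDescriptor(p), MakeCGridDescriptor(p));  // caller-built descriptors
    return 0;
}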
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_comp_instance.hpp
(new file, mode 100644; diff collapsed in the original page view)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
→ library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp
(renamed; diff collapsed in the original page view)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -11,7 +11,9 @@
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

 #ifdef CK_USE_XDL
-#include "grouped_convolution_backward_data_xdl.inc"
+#include "grouped_convolution_backward_data_xdl_comp.inc"
+#include "grouped_convolution_backward_data_xdl_mem_inter.inc"
+#include "grouped_convolution_backward_data_xdl_mem_intra.inc"
 #endif
 #ifdef CK_USE_WMMA
 #include "grouped_convolution_backward_data_wmma.inc"

@@ -79,7 +81,12 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
                      is_same_v<ComputeTypeB, F16>)
         {
-            add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
+            add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_comp_instances(op_ptrs);
+            add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_intra_instances(
+                op_ptrs);
+            add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_inter_instances(
+                op_ptrs);
         }
 #endif
 #ifdef CK_ENABLE_FP32

The remaining XDL branches of the factory are rewritten the same way: the single *_instances(op_ptrs) call is replaced by the matching *_comp_instances, *_mem_intra_instances and *_mem_inter_instances calls for

- add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances   (@@ -87,7 +94,12 @@)
- add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances  (@@ -95,7 +107,11 @@)
- add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances   (@@ -108,7 +124,12 @@)
- add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances   (@@ -116,7 +137,12 @@)
- add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances  (@@ -124,7 +150,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances  (@@ -140,7 +170,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances  (@@ -149,7 +183,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances (@@ -158,7 +196,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances  (@@ -171,7 +213,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances  (@@ -189,7 +235,11 @@)
- add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances (@@ -198,7 +248,11 @@)

The f16-with-bf8/f8-compute branch is only renamed, not split into three (@@ -180,7 +226,7 @@):

-            add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
+            add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_mem_instances(
                 op_ptrs);
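After this change, the instance factory for grouped conv bwd data hands back the comp, mem_intra and mem_inter instance sets together. A usage sketch, assuming the usual CK factory entry point (DeviceOperationInstanceFactory<...>::GetInstances()) and the canonical spellings of the layout and element-wise types that correspond to the aliases used in the .inc declarations below; this is not code from the commit:

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
    2,
    ck::tensor_layout::convolution::NHWGK,
    ck::tensor_layout::convolution::GKYXC,
    ck::Tuple<>,
    ck::tensor_layout::convolution::NHWGC,
    ck::half_t,
    ck::half_t,
    ck::Tuple<>,
    ck::half_t,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

int main()
{
    // Aggregates the *_comp_, *_mem_intra_ and *_mem_inter_ instance sets registered above
    // for 2-D NHWGK/GKYXC/NHWGC fp16 backward data.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
}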
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_comp.inc
(new file, mode 100644, +199 lines)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_comp_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<
        2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, Empty_Tuple, F16,
        PassThrough, PassThrough, PassThrough>>>& instances);
#endif

// ... (further declarations, summarized below) ...

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

The header continues in the same pattern, one declaration per enabled data type (guarded by
CK_ENABLE_FP16 / CK_ENABLE_FP32 / CK_ENABLE_BF16) and layout:

conv2d backward data
- add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_{f16,f32,bf16}_comp_instances
  for DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC, ...>
- add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_{f16,f32,bf16}_comp_instances
  for DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC, Empty_Tuple, NHWGC, ...>

conv3d backward data
- add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_{f16,f32,bf16}_comp_instances
  for DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC, ...>
- add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_{f16,f32,bf16}_comp_instances
  for DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC, ...>
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_mem_inter.inc
(new file, mode 100644, +199 lines)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Structurally identical to grouped_convolution_backward_data_xdl_comp.inc above: inside
namespace ck::tensor_operation::device::instance it declares, under the same CK_ENABLE_FP16 /
CK_ENABLE_FP32 / CK_ENABLE_BF16 guards, one
add_device_grouped_conv{2d,3d}_bwd_data_xdl_*_mem_inter_instances function for each of the
twelve layout and data-type combinations (GNHWK/GKYXC/GNHWC, NHWGK/GKYXC/NHWGC,
GNDHWK/GKZYXC/GNDHWC, NDHWGK/GKZYXC/NDHWGC in f16, f32 and bf16), each taking a
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<...>>>& instances argument.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc
→ library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl_mem_intra.inc
(renamed, +13 -13)

Every declaration in the renamed header gets the _mem_intra_ spelling; the declared parameter
types (std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<...>>>& instances) are
unchanged. For example (@@ -9,7 +9,7 @@):

 #ifdef CK_ENABLE_FP16
-void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_intra_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                   GNHWK,
                                                                   GKYXC,

The same rename is applied to:
- add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_{f32,bf16}_instances
- add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_{f16,f32,bf16}_instances
- add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_{f16,f32,bf16}_instances
- add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_{f16,f32,bf16}_instances

and, with a _mem_ suffix instead, to the bf8/f8-compute declaration (@@ -193,7 +193,7 @@):

-void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
+void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_mem_instances(
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt

 # ONLY XDL_AND_WMMA_KERNELS
 add_instance_library(
     device_grouped_conv2d_bwd_data_instance
-    xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
-    xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
-    xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
-    xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
-    xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
-    xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_comp_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_comp_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_comp_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
+    xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_inter_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_inter_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_mem_inter_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
+    xdl/mem/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
     wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
     wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/
comp/
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_
comp_
instance.cpp
View file @
2edac9f1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_
comp_
instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances
(
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_
comp_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
// 1. Default
// 1. Default
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_bf16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_bf16_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataDefault
>
{});
ConvBwdDataDefault
>
{});
// 2. Filter1x1Stride1Pad0
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_bf16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_bf16_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataFilter1x1Stride1Pad0
>
{});
ConvBwdDataFilter1x1Stride1Pad0
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/
comp/
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_
comp_
instance.cpp
View file @
2edac9f1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_
comp_
instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances
(
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_
comp_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
// 1. Default
// 1. Default
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_f16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_f16_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataDefault
>
{});
ConvBwdDataDefault
>
{});
// 2. Filter1x1Stride1Pad0
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_f16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_f16_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataFilter1x1Stride1Pad0
>
{});
ConvBwdDataFilter1x1Stride1Pad0
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/
comp/
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_
comp_
instance.cpp
View file @
2edac9f1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_
comp_
instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances
(
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_
comp_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
// 1. Default
// 1. Default
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_f32_instances
<
2
,
device_grouped_conv_bwd_data_xdl_f32_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataDefault
>
{});
ConvBwdDataDefault
>
{});
// 2. Filter1x1Stride1Pad0
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_f32_instances
<
2
,
device_grouped_conv_bwd_data_xdl_f32_
comp_
instances
<
2
,
GNHWK
,
GNHWK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
GNHWC
,
GNHWC
,
ConvBwdDataFilter1x1Stride1Pad0
>
{});
ConvBwdDataFilter1x1Stride1Pad0
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/
comp/
device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_
comp_
instance.cpp
View file @
2edac9f1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_
comp_
instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances
(
void
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_
comp_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
NHWGK
,
NHWGK
,
GKYXC
,
GKYXC
,
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
// 1. Default
// 1. Default
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_bf16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_bf16_
comp_
instances
<
2
,
NHWGK
,
NHWGK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
NHWGC
,
NHWGC
,
ConvBwdDataDefault
>
{});
ConvBwdDataDefault
>
{});
// 2. Filter1x1Stride1Pad0
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
device_grouped_conv_bwd_data_xdl_bf16_instances
<
2
,
device_grouped_conv_bwd_data_xdl_bf16_
comp_
instances
<
2
,
NHWGK
,
NHWGK
,
GKYXC
,
GKYXC
,
Empty_Tuple
,
Empty_Tuple
,
NHWGC
,
NHWGC
,
ConvBwdDataFilter1x1Stride1Pad0
>
{});
ConvBwdDataFilter1x1Stride1Pad0
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/
comp/
device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_
comp_
instance.cpp
View file @
2edac9f1
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_comp_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
-void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_comp_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  ...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
    // 1. Default
    add_device_operation_instances(
        instances,
-       device_grouped_conv_bwd_data_xdl_f16_instances<2,
+       device_grouped_conv_bwd_data_xdl_f16_comp_instances<2,
                                                            NHWGK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGC,
                                                            ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
-       device_grouped_conv_bwd_data_xdl_f16_instances<2,
+       device_grouped_conv_bwd_data_xdl_f16_comp_instances<2,
                                                            NHWGK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGC,
                                                            ConvBwdDataFilter1x1Stride1Pad0>{});
}

} // namespace instance
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp → library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/comp/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_comp_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
-void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_comp_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  ...
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
    // 1. Default
    add_device_operation_instances(
        instances,
-       device_grouped_conv_bwd_data_xdl_f32_instances<2,
+       device_grouped_conv_bwd_data_xdl_f32_comp_instances<2,
                                                            NHWGK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGC,
                                                            ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
-       device_grouped_conv_bwd_data_xdl_f32_instances<2,
+       device_grouped_conv_bwd_data_xdl_f32_comp_instances<2,
                                                            NHWGK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            NHWGC,
                                                            ConvBwdDataFilter1x1Stride1Pad0>{});
}

} // namespace instance
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_inter_instance.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_mem_inter_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_bf16_mem_instances<2,
                                                            GNHWK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            GNHWC,
                                                            ConvBwdDataDefault,
                                                            Interwave>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_bf16_mem_instances<2,
                                                            GNHWK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            GNHWC,
                                                            ConvBwdDataFilter1x1Stride1Pad0,
                                                            Interwave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
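A minimal, self-contained sketch of consuming the list registered by the new file above: populate a vector with the Interwave bf16 GNHWC instances and print each instance name. The include path, the `ck::bhalf_t`/layout/element-wise spellings, and `GetTypeString()` are assumptions about the surrounding library, not something this commit adds.

// Sketch only: the header path and helper spellings below are assumptions.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough; // assumed location
    using GNHWK       = ck::tensor_layout::convolution::GNHWK;           // assumed location
    using GKYXC       = ck::tensor_layout::convolution::GKYXC;
    using GNHWC       = ck::tensor_layout::convolution::GNHWC;
    using BF16        = ck::bhalf_t;                                     // assumed bf16 type
    using Empty_Tuple = ck::Tuple<>;

    // Element type copied from the signature of the registration function above.
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>
        instances;

    instance::add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_mem_inter_instances(instances);

    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n'; // GetTypeString() assumed from the base operator
    return 0;
}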
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_mem_intra_instance.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_mem_intra_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_bf16_mem_instances<2,
                                                            GNHWK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            GNHWC,
                                                            ConvBwdDataDefault,
                                                            Intrawave>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_bf16_mem_instances<2,
                                                            GNHWK,
                                                            GKYXC,
                                                            Empty_Tuple,
                                                            GNHWC,
                                                            ConvBwdDataFilter1x1Stride1Pad0,
                                                            Intrawave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_inter_instance.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_inter_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_f16_mem_instances<2,
                                                           GNHWK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           GNHWC,
                                                           ConvBwdDataDefault,
                                                           Interwave>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_f16_mem_instances<2,
                                                           GNHWK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           GNHWC,
                                                           ConvBwdDataFilter1x1Stride1Pad0,
                                                           Interwave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/mem/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_mem_intra_instance.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_mem_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_intra_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  GNHWK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  GNHWC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_f16_mem_instances<2,
                                                           GNHWK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           GNHWC,
                                                           ConvBwdDataDefault,
                                                           Intrawave>{});
    // 2. Filter1x1Stride1Pad0
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_f16_mem_instances<2,
                                                           GNHWK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           GNHWC,
                                                           ConvBwdDataFilter1x1Stride1Pad0,
                                                           Intrawave>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
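The only difference between the `_mem_inter_` and `_mem_intra_` files above is the final template argument passed to the mem instance tuple (Interwave vs. Intrawave). A caller sweeping over both scheduler flavors can register them into the same vector; the `DeviceOpF16Ptr` alias below is hypothetical shorthand for the element type spelled out in the two f16 files above.

// Sketch only (not part of this commit): collect both scheduler flavors for a tuning sweep.
// DeviceOpF16Ptr is assumed to alias std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
// GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough,
// PassThrough>>, i.e. the element type used by both registration functions.
std::vector<DeviceOpF16Ptr> instances;
instance::add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_inter_instances(instances);
instance::add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_mem_intra_instances(instances);
// Each op can then be benchmarked and the fastest one picked at run time.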