Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4d914af3
Commit
4d914af3
authored
Oct 31, 2024
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
223a2abe
4b798833
Changes
333
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
714 additions
and
36 deletions
+714
-36
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
...weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
...weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
...o_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
...o_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
...o_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
...o_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
...eight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp
+15
-13
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
...eight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
+48
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...v3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
...instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
+8
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+55
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
..._fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+55
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
..._fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+55
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
...fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+54
-0
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
.../include/profiler/profile_gemm_multiply_multiply_impl.hpp
+6
-4
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
+31
-7
profiler/src/profile_gemm_multiply_multiply.cpp
profiler/src/profile_gemm_multiply_multiply.cpp
+9
-1
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+19
-4
profiler/src/profile_grouped_conv_bwd_weight.cpp
profiler/src/profile_grouped_conv_bwd_weight.cpp
+23
-2
python/ck4inductor/grouped_conv_fwd/gen_instances.py
python/ck4inductor/grouped_conv_fwd/gen_instances.py
+167
-0
No files found.
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_
f32_bf16_
instance.cpp
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_
f32_bf16_
instance.cpp
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp"
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances
<
3
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_
f32_bf16_
instance.cpp
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...
...
@@ -24,19 +24,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ConvBwdWeightDefault
>
{});
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances
<
3
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...
...
@@ -10,13 +10,13 @@ namespace device {
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_
f32_bf16_
instances
(
void
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
BF16
,
F32
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt
0 → 100644
View file @
4d914af3
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_DYNAMIC_OP
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_dynamic_op_instance
${
GROUPED_CONV3D_FWD_DYNAMIC_OP
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
BF16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F16
,
F16
,
ck
::
Tuple
<>
,
F16
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F32
,
F32
,
ck
::
Tuple
<>
,
F32
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
int8_t
,
int8_t
,
ck
::
Tuple
<>
,
int8_t
,
PassThrough
,
PassThrough
,
DynamicUnaryOp
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_dynamic_op_int8_instances
<
3
,
NDHWGC
,
GKZYXC
,
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp
View file @
4d914af3
...
...
@@ -271,10 +271,12 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_INT8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
EDataType
,
f8_t
>
)
if
constexpr
((
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
EDataType
,
f8_t
>
)
||
(
is_same_v
<
ADataType
,
int8_t
>
||
is_same_v
<
BDataType
,
int8_t
>
||
is_same_v
<
EDataType
,
int8_t
>
))
{
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
...
...
@@ -286,7 +288,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
);
#if defined CK_ENABLE_FP8
#if defined CK_ENABLE_FP8
|| defined CK_ENABLE_INT8
}
#endif
...
...
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
View file @
4d914af3
...
...
@@ -102,11 +102,22 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
Tensor
<
IndexDataType
>
out_indices_n_c_do_ho_wo_device
(
f_host_tensor_descriptor
(
N
,
C
,
Do
,
Ho
,
Wo
));
constexpr
int
inDataRangeTensor1
{
1
};
constexpr
int
inDataRangeTensor2
{
5
};
constexpr
double
inDataRangeTensor3
{
0.5
};
switch
(
in_params
.
init_method
)
{
case
0
:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{});
break
;
case
1
:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
break
;
default:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
0.5
,
0.5
});
case
0
:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
inDataRangeTensor1
});
break
;
case
1
:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
inDataRangeTensor2
,
inDataRangeTensor2
});
break
;
default:
in_n_c_di_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
inDataRangeTensor3
,
inDataRangeTensor3
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_di_hi_wi
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -229,12 +240,25 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{
out_device_buf
.
FromDevice
(
out_n_c_do_ho_wo_device
.
mData
.
data
());
auto
tolerance
=
1e-3
;
bool
pass
=
ck
::
utils
::
check_err
(
out_n_c_do_ho_wo_device
.
mData
,
auto
absolute_error_threshold
=
1.0
;
switch
(
in_params
.
init_method
)
{
case
0
:
absolute_error_threshold
=
static_cast
<
double
>
(
inDataRangeTensor1
);
break
;
case
1
:
absolute_error_threshold
=
static_cast
<
double
>
(
inDataRangeTensor2
);
break
;
default:
absolute_error_threshold
=
inDataRangeTensor3
;
}
absolute_error_threshold
=
ck
::
utils
::
get_absolute_threshold
<
ComputeDataType
,
OutDataType
>
(
absolute_error_threshold
);
auto
relative_error_threshold
=
ck
::
utils
::
get_relative_threshold
<
ComputeDataType
,
OutDataType
>
();
bool
pass
=
ck
::
utils
::
check_err
(
out_n_c_do_ho_wo_device
.
mData
,
out_n_c_do_ho_wo_host
.
mData
,
"Error: Incorrect results"
,
tolerance
,
tolerance
);
relative_error_threshold
,
absolute_error_threshold
);
if
constexpr
(
OutputIndex
)
{
...
...
profiler/src/profile_gemm_multiply_multiply.cpp
View file @
4d914af3
...
...
@@ -27,6 +27,7 @@ enum struct GemmDataType
F16_F8_F16
,
// 5
F16_F16_F16_F8
,
// 6
F8_F8_BF16
,
// 7
INT8_INT8_BF16
,
// 8
};
#define OP_NAME "gemm_multiply_multiply"
...
...
@@ -39,7 +40,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf
(
"arg1: tensor operation ("
OP_NAME
": "
OP_DESC
")
\n
"
);
printf
(
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8)
\n
"
);
"comp f8
; 8: int8->bf16
)
\n
"
);
printf
(
"arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
" 1: A[m, k] * B[n, k] = C[m, n];
\n
"
);
printf
(
" 2: A[k, m] * B[k, n] = C[m, n];
\n
"
);
...
...
@@ -89,6 +90,8 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using
F32
=
float
;
using
BF16
=
ck
::
bhalf_t
;
using
F8
=
ck
::
f8_t
;
using
I8
=
int8_t
;
using
I32
=
int
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -162,6 +165,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
GemmDataType
::
INT8_INT8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
return
profile
(
I8
{},
I8
{},
I8
{},
I32
{},
F32
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{},
Col
{},
Row
{});
}
else
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
profiler/src/profile_gemm_universal.cpp
View file @
4d914af3
...
...
@@ -57,6 +57,25 @@ int profile_gemm_universal(int argc, char* argv[])
exit
(
1
);
}
int
M
;
int
N
;
int
StrideA
;
int
StrideB
;
// Analyze the unsupported matrix shapes, switch the M and N number
if
(
std
::
stoi
(
argv
[
9
])
%
8
!=
0
&&
std
::
stoi
(
argv
[
8
])
%
8
==
0
)
{
M
=
std
::
stoi
(
argv
[
9
]);
StrideA
=
std
::
stoi
(
argv
[
12
]);
N
=
std
::
stoi
(
argv
[
8
]);
StrideB
=
std
::
stoi
(
argv
[
11
]);
}
else
{
M
=
std
::
stoi
(
argv
[
8
]);
StrideA
=
std
::
stoi
(
argv
[
11
]);
N
=
std
::
stoi
(
argv
[
9
]);
StrideB
=
std
::
stoi
(
argv
[
12
]);
}
const
auto
data_type
=
static_cast
<
GemmDataType
>
(
std
::
stoi
(
argv
[
2
]));
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
3
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
4
]);
...
...
@@ -64,12 +83,8 @@ int profile_gemm_universal(int argc, char* argv[])
const
bool
do_log
=
std
::
stoi
(
argv
[
6
]);
const
bool
time_kernel
=
std
::
stoi
(
argv
[
7
]);
const
int
M
=
std
::
stoi
(
argv
[
8
]);
const
int
N
=
std
::
stoi
(
argv
[
9
]);
const
int
K
=
std
::
stoi
(
argv
[
10
]);
const
int
StrideA
=
std
::
stoi
(
argv
[
11
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
KBatch
=
std
::
stoi
(
argv
[
14
]);
...
...
profiler/src/profile_grouped_conv_bwd_weight.cpp
View file @
4d914af3
...
...
@@ -25,7 +25,8 @@ enum struct ConvDataType
F16_F16_F16
,
// 1
BF16_F32_BF16
,
// 2
F16_F16_F16_BF8_F8
,
// 3
I8_I8_I8
// 4
I8_I8_I8
,
// 4
BF16_BF16_BF16
,
// 5
};
#define OP_NAME "grouped_conv_bwd_weight"
...
...
@@ -38,7 +39,8 @@ static void print_helper_msg()
<<
" 1: Input fp16, Weight fp16, Output fp16
\n
"
<<
" 2: Input bf16, Weight fp32, Output bf16
\n
"
<<
" 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8
\n
"
<<
" 4: Input int8, Weight int8, Output int8)
\n
"
<<
" 4: Input int8, Weight int8, Output int8
\n
"
<<
" 5: Input bf16, Weight bf16, Output bf16)
\n
"
<<
"arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]
\n
"
<<
" 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
...
...
@@ -180,6 +182,10 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
F32
{},
BF16
{},
BF16
{},
BF16
{});
}
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I2
,
NHWGC
{},
GKYXC
{},
NHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NGCHW_GKYXC_NGKHW
)
{
...
...
@@ -187,6 +193,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
return
profile
(
I2
,
NGCHW
{},
GKYXC
{},
NGKHW
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I2
,
NGCHW
{},
GKYXC
{},
NGKHW
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
GNHWC_GKYXC_GNHWK
)
{
...
...
@@ -224,6 +235,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
F32
{},
BF16
{},
BF16
{},
BF16
{});
}
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
if
(
data_type
==
ConvDataType
::
F16_F16_F16_BF8_F8
)
{
return
profile
(
I3
,
NDHWGC
{},
GKZYXC
{},
NDHWGK
{},
F16
{},
F16
{},
F16
{},
BF8
{},
F8
{});
...
...
@@ -240,6 +256,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
return
profile
(
I3
,
NGCDHW
{},
GKZYXC
{},
NGKDHW
{},
F16
{},
F16
{},
F16
{},
F16
{},
F16
{});
}
if
(
data_type
==
ConvDataType
::
BF16_BF16_BF16
)
{
return
profile
(
I3
,
NGCDHW
{},
GKZYXC
{},
NGKDHW
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{},
BF16
{});
}
}
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
python/ck4inductor/grouped_conv_fwd/gen_instances.py
0 → 100644
View file @
4d914af3
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
os
import
subprocess
from
dataclasses
import
replace
from
functools
import
lru_cache
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKGroupedConvFwdOp
log
=
logging
.
getLogger
(
__name__
)
def
_ck_conv_instances_path
():
conv_instances_path
=
os
.
path
.
join
(
# noqa: F821
library_path
(),
"include"
,
"ck"
,
"library"
,
"tensor_operation_instance"
,
"gpu"
,
"grouped_conv_fwd"
,
)
if
not
os
.
path
.
exists
(
conv_instances_path
):
log
.
error
(
"CK library conv instances path %s does not exist"
,
conv_instances_path
)
return
None
return
conv_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKGroupedConvFwdOp
]:
"""
Parse the lines containing Grouped Convolution Forward template instances
into `CKGroupedConvFwdOp` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
# TODO: maybe use libclang for parsing C++ code in the future
# to avoid this hacky parsing logic below ? :) - copilot
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
template_args
[
0
]
=
-
1
# n_dim_spatial
template_args
[
3
]
=
tuple
()
# ds_layout
template_args
[
9
]
=
tuple
()
# ds_element_dtype
new_instance
=
CKGroupedConvFwdOp
(
*
template_args
,
# type: ignore[arg-type]
)
op_instances
.
append
(
new_instance
)
return
op_instances
@
lru_cache
(
None
)
def
gen_conv_ops_library
()
->
List
[
CKGroupedConvFwdOp
]:
"""
Parse the Grouped Convolution Forward instances
defined in the Composable Kernel library folder.
"""
ck_library_dir
=
_ck_conv_instances_path
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
,
ck_library_dir
,
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
conv_specs
=
[
"ConvolutionForwardSpecialization::Default"
,
"ConvolutionForwardSpecialization::Filter1x1Pad0"
,
"ConvolutionForwardSpecialization::Filter1x1Stride1Pad0"
,
"ConvolutionForwardSpecialization::OddC"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
(
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
)
sub_spec
=
instance
.
conv_forward_specialization
==
"ConvSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
conv_specs
if
sub_spec
else
[
instance
.
conv_forward_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
for
channels_last
in
[
True
,
False
]:
if
channels_last
:
a_layout
=
"NHWGC"
e_layout
=
"NHWGK"
else
:
a_layout
=
"NGCHW"
e_layout
=
"NGKHW"
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
conv_forward_specialization
=
spec
,
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
n_dim_spatial
=
2
,
a_layout
=
a_layout
,
b_layout
=
"GKYXC"
,
e_layout
=
e_layout
,
)
)
return
substitute_instances
if
__name__
==
"__main__"
:
print
(
gen_conv_ops_library
())
Prev
1
…
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment