Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
408534d4
Unverified
Commit
408534d4
authored
Aug 09, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Aug 09, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
a8efb3f0
da214a5a
Changes
204
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
521 additions
and
129 deletions
+521
-129
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
...iversal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
...niversal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
...niversal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
...iversal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
...ersal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
...niversal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
...iversal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
...ersal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
..._operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
+4
-4
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
..._fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+6
-15
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
...d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
+8
-17
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
...d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+11
-20
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
..._operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
+3
-3
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
...d_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+6
-14
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
...wd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+8
-16
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
...wd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+39
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
...tance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
+5
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
...wd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+36
-20
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
+55
-0
library/src/utility/convolution_parameter.cpp
library/src/utility/convolution_parameter.cpp
+78
-20
No files found.
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances
<
GemmMNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances
<
GemmMNPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Intrawave
,
GemmMNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
F16
,
F16
,
DsDataType
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances
<
Interwave
,
GemmMNKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
View file @
408534d4
...
...
@@ -9,11 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
#
me
rge
d groups
#
la
rge
tensor
# NHWGC, GKYXC, NHWGK
xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f32_instance.cpp
#mem
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_
f32
_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_
bf16
_instance.cpp
View file @
408534d4
...
...
@@ -2,44 +2,35 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_
f32
_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_
bf16
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
F32
,
F32
,
BF16
,
BF16
,
Empty_Tuple
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_
me
rge
d_groups_f32
_instances
<
2
,
device_grouped_conv_fwd_xdl_
la
rge
_tensor_bf16
_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f16_instance.cpp
View file @
408534d4
...
...
@@ -2,14 +2,14 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_f16_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -25,21 +25,12 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd3x3
>
{});
device_grouped_conv_fwd_xdl_large_tensor_f16_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
me
rge
d_groups
/device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_
bf16
_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/
la
rge
_tensor
/device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_
f32
_instance.cpp
View file @
408534d4
...
...
@@ -2,44 +2,35 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_
me
rge
d_groups
_nhwgc_gkyxc_nhwgk_
bf16
_instances
(
void
add_device_grouped_conv2d_fwd_xdl_
la
rge
_tensor
_nhwgc_gkyxc_nhwgk_
f32
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
BF16
,
BF16
,
F32
,
F32
,
Empty_Tuple
,
BF16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd3x3
>
{});
device_grouped_conv_fwd_xdl_large_tensor_f32_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
View file @
408534d4
...
...
@@ -9,9 +9,9 @@ set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
xdl/
me
rge
d_groups
/device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/
me
rge
d_groups
/device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/
me
rge
d_groups
/device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/
la
rge
_tensor
/device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/
me
rge
d_groups
/device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_
f32
_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/
la
rge
_tensor
/device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_
bf16
_instance.cpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -9,36 +9,28 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_
f32
_instances
(
void
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_
bf16
_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
F32
,
F32
,
BF16
,
BF16
,
Empty_Tuple
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_
me
rge
d_groups_f32
_instances
<
3
,
device_grouped_conv_fwd_xdl_
la
rge
_tensor_bf16
_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/
me
rge
d_groups
/device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/
la
rge
_tensor
/device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
me
rge
d_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
la
rge
_tensor
_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
...
...
@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_
me
rge
d_groups
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
void
add_device_grouped_conv3d_fwd_xdl_
la
rge
_tensor
_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
...
...
@@ -25,20 +25,12 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_in
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
device_grouped_conv_fwd_xdl_large_tensor_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
0 → 100644
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_large_tensor_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt
0 → 100644
View file @
408534d4
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_CONVSCALE_ADD
xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_add_instance
${
GROUPED_CONV3D_FWD_CONVSCALE_ADD
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd
/xdl/merged_groups
/device_grouped_conv3d_fwd_xdl_
merged_groups
_ndhwgc_gkzyxc_ndhwgk_
bf16
_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd
_convscale_add/xdl
/device_grouped_conv3d_fwd_xdl_
convscale_add
_ndhwgc_gkzyxc_ndhwgk_
f8
_instance.cpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
merged_groups
_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_
binary_outelementop
_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
using
ConvScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ConvScaleAdd
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
ck
::
Tuple
<
NDHWGK
>
,
NDHWGK
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
F8
,
F8
,
ck
::
Tuple
<
F32
>
,
F8
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
ConvScaleAdd
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
>
,
NDHWGK
,
ConvFwdDefault
,
ConvScaleAdd
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
>
,
NDHWGK
,
ConvFwd1x1P0
,
ConvScaleAdd
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
device_grouped_conv_fwd_xdl_binary_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<
NDHWGK
>
,
NDHWGK
,
ConvFwd1x1S1P0
,
ConvScaleAdd
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
0 → 100644
View file @
408534d4
set
(
FMHA_CPP_FOLDER
${
CMAKE_CURRENT_BINARY_DIR
}
)
set
(
FMHA_SRC_FOLDER
${
CMAKE_SOURCE_DIR
}
/example/ck_tile/01_fmha/
)
set
(
CK_TILE_SRC_FOLDER
${
CMAKE_SOURCE_DIR
}
/include/ck_tile/
)
# python stuff
find_package
(
PythonInterp 3 REQUIRED
)
rocm_install
(
DIRECTORY
${
CK_TILE_SRC_FOLDER
}
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck_tile
)
rocm_install
(
FILES
"
${
FMHA_SRC_FOLDER
}
/fmha_fwd.hpp"
"
${
FMHA_SRC_FOLDER
}
/bias.hpp"
"
${
FMHA_SRC_FOLDER
}
/mask.hpp"
DESTINATION include/ck_tile/ops
)
# header for building lib
file
(
COPY
${
FMHA_SRC_FOLDER
}
/fmha_fwd.hpp DESTINATION
${
FMHA_CPP_FOLDER
}
)
file
(
COPY
${
FMHA_SRC_FOLDER
}
/bias.hpp DESTINATION
${
FMHA_CPP_FOLDER
}
)
file
(
COPY
${
FMHA_SRC_FOLDER
}
/mask.hpp DESTINATION
${
FMHA_CPP_FOLDER
}
)
# generate a list of kernels, but not actually emit files at config stage
execute_process
(
COMMAND
${
PYTHON_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/example/ck_tile/01_fmha/generate.py
--list_blobs
${
FMHA_CPP_FOLDER
}
/blob_list.txt
)
file
(
STRINGS
${
FMHA_CPP_FOLDER
}
/blob_list.txt FMHA_FWD_GEN_BLOBS
)
# actually generate the cpp files
add_custom_command
(
OUTPUT
${
FMHA_FWD_GEN_BLOBS
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/example/ck_tile/01_fmha/generate.py
--output_dir
${
FMHA_CPP_FOLDER
}
COMMENT
"Generating mha kernel (cpp) files now ..."
VERBATIM
)
# This is done to remove path info and just
# have filename. Since, it was cauing the cmake
# to throw "File name too long"
set
(
device_files
)
foreach
(
filepath IN LISTS FMHA_FWD_GEN_BLOBS
)
get_filename_component
(
filename
${
filepath
}
NAME
)
# Append the filename to the device_files list
list
(
APPEND device_files
${
filename
}
)
endforeach
()
add_custom_target
(
generate_cpp_files DEPENDS
${
FMHA_FWD_GEN_BLOBS
}
)
add_instance_library
(
device_mha_instance
${
device_files
}
)
if
(
TARGET device_mha_instance
)
add_dependencies
(
device_mha_instance generate_cpp_files
)
endif
()
library/src/utility/convolution_parameter.cpp
View file @
408534d4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host_utility/io.hpp"
...
...
@@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim,
const
std
::
vector
<
ck
::
index_t
>&
dilations
,
const
std
::
vector
<
ck
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
static_cast
<
ck
::
long_index_t
>
(
n_dim
)),
G_
(
static_cast
<
ck
::
long_index_t
>
(
group_count
)),
N_
(
static_cast
<
ck
::
long_index_t
>
(
n_batch
)),
K_
(
static_cast
<
ck
::
long_index_t
>
(
n_out_channels
)),
C_
(
static_cast
<
ck
::
long_index_t
>
(
n_in_channels
)),
filter_spatial_lengths_
(
num_dim_spatial_
),
input_spatial_lengths_
(
num_dim_spatial_
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
num_dim_spatial_
),
conv_filter_dilations_
(
num_dim_spatial_
),
input_left_pads_
(
num_dim_spatial_
),
input_right_pads_
(
num_dim_spatial_
)
{
if
(
static_cast
<
ck
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
filter_spatial_lengths_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
filters_len
[
i
]);
input_spatial_lengths_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
input_len
[
i
]);
conv_filter_strides_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
strides
[
i
]);
conv_filter_dilations_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
dilations
[
i
]);
input_left_pads_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
left_pads
[
i
]);
input_right_pads_
[
i
]
=
static_cast
<
ck
::
long_index_t
>
(
right_pads
[
i
]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ConvParam
::
ConvParam
(
ck
::
long_index_t
n_dim
,
ck
::
long_index_t
group_count
,
ck
::
long_index_t
n_batch
,
ck
::
long_index_t
n_out_channels
,
ck
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
long_index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
...
...
@@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim,
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck
::
index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
const
ck
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
...
...
@@ -63,7 +121,7 @@ ConvParam::ConvParam()
{
}
std
::
vector
<
ck
::
index_t
>
ConvParam
::
GetOutputSpatialLengths
()
const
std
::
vector
<
ck
::
long_
index_t
>
ConvParam
::
GetOutputSpatialLengths
()
const
{
return
output_spatial_lengths_
;
}
...
...
@@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg()
ck
::
utils
::
conv
::
ConvParam
parse_conv_param
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
const
ck
::
index_t
G
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
N
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
K
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
C
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck
::
index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck
::
index_t
>
input_right_pads
(
num_dim_spatial
);
const
ck
::
long_
index_t
G
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
const
ck
::
long_
index_t
N
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
const
ck
::
long_
index_t
K
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
const
ck
::
long_
index_t
C
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck
::
long_
index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck
::
long_
index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck
::
long_
index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck
::
long_
index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck
::
long_
index_t
>
input_right_pads
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
filter_spatial_lengths
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
filter_spatial_lengths
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_spatial_lengths
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
input_spatial_lengths
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_strides
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
conv_filter_strides
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_dilations
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
conv_filter_dilations
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_left_pads
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
input_left_pads
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_right_pads
[
i
]
=
std
::
sto
i
(
argv
[
arg_idx
++
]);
input_right_pads
[
i
]
=
std
::
sto
l
(
argv
[
arg_idx
++
]);
}
return
ck
::
utils
::
conv
::
ConvParam
{
num_dim_spatial
,
...
...
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment