Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a8efb3f0
Unverified
Commit
a8efb3f0
authored
Jul 16, 2024
by
Illia Silin
Committed by
GitHub
Jul 16, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
06701e70
9cac2827
Changes
51
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
352 additions
and
22 deletions
+352
-22
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
..._fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
+48
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
..._operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
+4
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
..._xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+47
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
...d_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+47
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
...d_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
+47
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
..._instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
+2
-1
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
...wd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
+62
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
...ance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
+5
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
...d_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+63
-0
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+18
-20
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+9
-1
No files found.
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt
View file @
a8efb3f0
...
...
@@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
BF16
,
BF16
,
Empty_Tuple
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
F32
,
F32
,
Empty_Tuple
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwdDefault
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_merged_groups_f32_instances
<
3
,
NDHWGC
,
GKZYXC
,
Empty_Tuple
,
NDHWGK
,
ConvFwd3x3
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
View file @
a8efb3f0
...
...
@@ -2,6 +2,7 @@
set
(
GROUPED_CONV3D_FWD_CONVSCALE
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
)
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_instance
${
GROUPED_CONV3D_FWD_CONVSCALE
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
ConvScale
=
ck
::
tensor_operation
::
element_wise
::
ConvScale
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
BF8
,
F8
,
ck
::
Tuple
<>
,
F8
,
PassThrough
,
PassThrough
,
ConvScale
,
BF8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
ConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
ConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
ConvScale
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt
0 → 100644
View file @
a8efb3f0
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_CONVSCALE_RELU
xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_relu_instance
${
GROUPED_CONV3D_FWD_CONVSCALE_RELU
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
0 → 100644
View file @
a8efb3f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
ConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ConvScaleRelu
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F8
,
PassThrough
,
PassThrough
,
ConvScaleRelu
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
ConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
ConvScaleRelu
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
ConvScaleRelu
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
a8efb3f0
...
...
@@ -191,7 +191,24 @@ bool profile_gemm_universal_impl(int do_verification,
{
c_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
CDataType
,
f8_t
>
)
{
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
double
atol
=
1e-1
;
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
msg
,
rtol
,
atol
);
}
else
{
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#if defined CK_ENABLE_FP8
}
#endif
if
(
do_log
)
{
...
...
@@ -230,25 +247,6 @@ bool profile_gemm_universal_impl(int do_verification,
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
", KBatch "
<<
kbatch_curr
<<
std
::
endl
;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if
constexpr
(
is_same_v
<
ADataType
,
f8_t
>
||
is_same_v
<
BDataType
,
f8_t
>
||
is_same_v
<
CDataType
,
f8_t
>
)
{
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-1
;
double
atol
=
1e-1
;
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
msg
,
rtol
,
atol
);
}
else
{
#endif
pass
=
pass
&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#if defined CK_ENABLE_FP8
}
#endif
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
a8efb3f0
...
...
@@ -104,6 +104,7 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D)
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
96
,
1
,
1
,
1
,
{
3
},
{
512
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
template
Run
<
1
>();
}
...
...
@@ -119,6 +120,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D)
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
96
,
1
,
1
,
1
,
{
3
,
3
},
{
120
,
160
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
template
Run
<
2
>();
}
...
...
@@ -137,6 +140,8 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D)
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
96
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
4
,
30
,
160
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
template
Run
<
3
>();
}
...
...
@@ -144,6 +149,9 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
{
// Case larger than 2GB
this
->
conv_params
.
push_back
(
{
2
,
1
,
64
,
4
,
192
,
{
2
,
2
},
{
224
,
224
},
{
224
,
224
},
{
0
,
0
},
{
0
,
0
},
{
0
,
0
}});
{
2
,
1
,
64
,
4
,
192
,
{
2
,
2
},
{
224
,
224
},
{
224
,
224
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
// With supported NumGroupsToMerge > 1
this
->
conv_params
.
push_back
(
{
2
,
32
,
64
,
1
,
1
,
{
2
,
2
},
{
672
,
672
},
{
672
,
672
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
template
Run
<
2
>();
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment