gaoqiong / composable_kernel · Commit 1a0f2c35

Add grouped conv bwd weight wmma

Authored by Bartlomiej Kocot, Oct 12, 2023; committed by Bartłomiej Kocot, Oct 12, 2023.
Parent: f3b02ecf
Showing 7 changed files with 245 additions and 5 deletions (+245 -5):
- library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp (+0 -0)
- library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp (+0 -0)
- library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp (+0 -0)
- test/grouped_convnd_bwd_weight/CMakeLists.txt (+12 -3)
- test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp (+42 -2)
- test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp (+191 -0)
- test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp (+0 -0)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
→ library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
File moved.

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
→ library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
File moved.

library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
→ library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
File moved.
test/grouped_convnd_bwd_weight/CMakeLists.txt
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
+list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
+    if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
         add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
         target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
-        add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp)
+        add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
         target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
         set(target 1)
     endif()
+    if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
+        add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
+        target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
+        add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
+        target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
+        set(target 1)
+    endif()
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp

@@ -11,6 +11,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/host_utility/device_prop.hpp"
 #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"

@@ -33,7 +34,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
     bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
     {
-        // Odd K or C values are supported only by DL kernel (only applies to fp16)
+        // Odd K or C values are supported only by DL and WMMA
+        // kernels (only applies to fp16)
         // DL kernel currently supports only `split_k=1`
         if constexpr(std::is_same_v<InDataType, ck::half_t>)
         {

@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
         }
+        const bool is_navi3x = ck::get_device_name() == "gfx1100" ||
+                               ck::get_device_name() == "gfx1101" ||
+                               ck::get_device_name() == "gfx1102";
+        if(is_navi3x)
+        {
+            // on navi3x only support for 3d is implemented
+            if constexpr(NDimSpatial{} != 3)
+            {
+                return true;
+            }
+            // on navi3x only support for i8 and fp16 is implemented
+            if constexpr(!((std::is_same_v<InDataType, int8_t> &&
+                            std::is_same_v<WeiDataType, int8_t> &&
+                            std::is_same_v<OutDataType, int8_t>) ||
+                           (std::is_same_v<InDataType, ck::half_t> &&
+                            std::is_same_v<WeiDataType, ck::half_t> &&
+                            std::is_same_v<OutDataType, ck::half_t>)))
+            {
+                return true;
+            }
+            // WMMA kernel is only supported for split_k=1
+            if(split_k != 1)
+            {
+                return true;
+            }
+        }
+        else
+        {
+            // support for i8 is only implemented on navi3x
+            if constexpr(std::is_same_v<InDataType, int8_t> &&
+                         std::is_same_v<WeiDataType, int8_t> &&
+                         std::is_same_v<OutDataType, int8_t>)
+            {
+                return true;
+            }
+        }
         return false;
     }
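The gate above keys off the device name string returned by ck::get_device_name(), which is why the first hunk adds the ck/host_utility/device_prop.hpp include. As a reading aid, here is a minimal standalone sketch of the same check; the helper name is_navi3x_device is hypothetical and not part of this commit:

#include <string>

#include "ck/host_utility/device_prop.hpp" // declares ck::get_device_name()

// Hypothetical helper mirroring the gate above: true on the gfx11xx
// (navi3x) targets that the WMMA backward-weight kernel covers.
inline bool is_navi3x_device()
{
    const std::string dev = ck::get_device_name();
    return dev == "gfx1100" || dev == "gfx1101" || dev == "gfx1102";
}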
@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
     std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
     std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
     std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
+    std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
     std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
     std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
-    std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
+    std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
+    std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;

 TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
 TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
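For readers scanning the type list: the two added rows extend int8 coverage to both layout families. Each tuple appears to follow the schema annotated below, inferred from the layout triple ordering in this test rather than stated in the diff; the alias name Int8NdhwgcCase is hypothetical:

// Inferred schema (assumption): <InDataType, WeiDataType, OutDataType,
//                                InLayout, WeiLayout, OutLayout, ck::Number<NDim>>
using Int8NdhwgcCase =
    std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>;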
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include <gtest/gtest.h>

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using ConvolutionBackwardWeightSpecialization =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;

static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto Filter1x1Stride1Pad0 =
    ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;

template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
class TestGroupedConvndBwdWeight : public ::testing::Test
{
    protected:
    using OutLayout = std::tuple_element_t<0, Tuple>;
    using WeiLayout = std::tuple_element_t<1, Tuple>;
    using InLayout  = std::tuple_element_t<2, Tuple>;

    static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};
    // clang-format off
    using GroupedConvBwdWeightDeviceInstance =
        ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle
        //| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
        //| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
        //| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
        < NDimSpatial, InLayout, WeiLayout, OutLayout, F16, F16, F16, F32,
          PassThrough, PassThrough, PassThrough, ConvSpec,
          128, 128, 128, 8, 8, 16, 16, 4, 4,
          S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1,
          S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1,
          1, 1, S<1, 32, 1, 4>, 8>;
    // clang-format on

    ck::utils::conv::ConvParam conv_param;
    template <ck::index_t SplitK>
    bool Run()
    {
        const auto in_g_n_c_wis_desc =
            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
                conv_param);
        const auto wei_g_k_c_xs_desc =
            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
                conv_param);
        const auto out_g_n_k_wos_desc =
            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
                conv_param);

        std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
        std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
        std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
        std::array<ck::index_t, NDimSpatial + 3> input_strides{};
        std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
        std::array<ck::index_t, NDimSpatial + 3> output_strides{};
        std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
        std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
        std::array<ck::index_t, NDimSpatial> input_left_pads{};
        std::array<ck::index_t, NDimSpatial> input_right_pads{};

        auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };

        range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
        range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
        range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
        range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
        range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
        range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
        range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
        range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
        range_copy(conv_param.input_left_pads_, begin(input_left_pads));
        range_copy(conv_param.input_right_pads_, begin(input_right_pads));

        auto conv = GroupedConvBwdWeightDeviceInstance{};

        auto argument = conv.MakeArgument(nullptr,
                                          nullptr,
                                          nullptr,
                                          input_lengths,
                                          input_strides,
                                          filter_lengths,
                                          weights_strides,
                                          output_lengths,
                                          output_strides,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          PassThrough{},
                                          PassThrough{},
                                          PassThrough{},
                                          SplitK);
        return conv.IsSupportedArgument(argument);
    }
};
using namespace ck::tensor_layout::convolution;

using KernelTypes3d = ::testing::Types<std::tuple<GNDHWK, GKZYXC, GNDHWC, ck::Number<3>>,
                                       std::tuple<NDHWGK, GKZYXC, NDHWGC, ck::Number<3>>>;

template <typename Tuple>
class TestGroupedConvndBwdWeightFilter1x13d
    : public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
{
};

template <typename Tuple>
class TestGroupedConvndBwdWeightDefault3d
    : public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
{
};

TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{
    // Check filter 3,3 instead of 1,1
    this->conv_param = {
        3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    bool is_supported = this->template Run<1>();
    EXPECT_FALSE(is_supported);

    // Check strides 2,2 instead of 1,1
    this->conv_param = {
        3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    is_supported = this->template Run<1>();
    EXPECT_FALSE(is_supported);

    // Check with pad
    this->conv_param = {
        3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
    is_supported = this->template Run<1>();
    EXPECT_FALSE(is_supported);

    // Supported version
    this->conv_param = {
        3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    is_supported = this->template Run<1>();
    EXPECT_TRUE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck)
{
    // vector load for A
    this->conv_param = {
        3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    bool is_supported = this->template Run<1>();
    EXPECT_FALSE(is_supported);

    // vector load for B, E, Ds
    this->conv_param = {
        3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    is_supported = this->template Run<1>();
    EXPECT_FALSE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck)
{
    // SplitK=1
    this->conv_param = {
        3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    bool is_supported = this->template Run<1>();
    EXPECT_TRUE(is_supported);

    // SplitK=2
    this->conv_param = {
        3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    is_supported = this->template Run<2>();
    EXPECT_FALSE(is_supported);
}
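Note that these interface tests pass nullptr data pointers to MakeArgument and only exercise IsSupportedArgument; no kernel is launched. When a configuration is supported, CK device ops are dispatched through the usual argument/invoker pattern. A minimal sketch under that assumption follows; run_if_supported is a hypothetical helper, not part of this commit, and StreamConfig comes from ck/stream_config.hpp:

#include <utility>

#include "ck/stream_config.hpp"

// Hypothetical helper: build the argument, verify support, then launch.
// Returns the measured kernel time in ms, or -1 if unsupported.
template <typename DeviceOp, typename... Args>
float run_if_supported(DeviceOp& op, Args&&... args)
{
    auto argument = op.MakeArgument(std::forward<Args>(args)...);
    if(!op.IsSupportedArgument(argument))
    {
        return -1.f; // configuration rejected, nothing launched
    }
    auto invoker = op.MakeInvoker();
    // time_kernel=true asks CK to time the launch and report milliseconds
    return invoker.Run(argument, StreamConfig{nullptr, true});
}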
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
→ test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp
File moved.