Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4501f13
Commit
a4501f13
authored
Jan 21, 2025
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ck_tile_gemm_policy
parents
c6dcf20d
e7dce4d2
Changes
368
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
160 additions
and
61 deletions
+160
-61
test/ck_tile/gemm/CMakeLists.txt
test/ck_tile/gemm/CMakeLists.txt
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+42
-0
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+5
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+55
-53
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+1
-0
test/data_type/test_bhalf.cpp
test/data_type/test_bhalf.cpp
+48
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+7
-2
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+1
-0
No files found.
test/ck_tile/gemm/CMakeLists.txt
View file @
a4501f13
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_ck_tile_gemm_
mem_
pipeline test_gemm_
mem_
pipeline.cpp
)
add_gtest_executable
(
test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp
)
endif
()
test/ck_tile/gemm/test_gemm_
mem_
pipeline.cpp
→
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
a4501f13
...
...
@@ -6,7 +6,7 @@
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
#include "test_gemm_
mem_
pipeline_util.hpp"
#include "test_gemm_pipeline_util.hpp"
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
...
...
@@ -16,21 +16,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemm
Mem
Pipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmPipeline
,
KernelTypes
);
#include "test_gemm_
mem_
pipeline_ut_cases.inc"
#include "test_gemm_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_
mem_
pipeline_ut_cases.inc
→
test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
View file @
a4501f13
...
...
@@ -3,7 +3,7 @@
#pragma once
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
SmallM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -13,7 +13,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -23,7 +23,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -33,7 +33,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
@@ -43,7 +43,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, Regular)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemm
Mem
Pipeline
,
NotSupportedArgument
)
TYPED_TEST
(
TestCkTileGemmPipeline
,
NotSupportedArgument
)
{
constexpr
int
M
=
512
;
constexpr
int
N
=
1025
;
...
...
test/ck_tile/gemm/test_gemm_
mem_
pipeline_util.hpp
→
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
a4501f13
...
...
@@ -11,36 +11,28 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
enum
struct
GemmPipelineType
{
Mem
,
Comp
};
template
<
typename
Tuple
>
class
TestCkTileGemm
Mem
Pipeline
:
public
::
testing
::
Test
class
TestCkTileGemmPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
std
::
tuple_element_t
<
7
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
void
invoke_gemm
(
const
gemm_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
...
@@ -74,10 +66,17 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
...
@@ -85,27 +84,30 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr
bool
has_hot_loop_v
=
has_hot_loop_
.
value
;
constexpr
auto
tail_number_v
=
tail_number_
.
value
;
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
;
using
GemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
...
@@ -297,11 +299,11 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_a
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
...
...
test/data_type/CMakeLists.txt
View file @
a4501f13
...
...
@@ -49,3 +49,4 @@ if(result EQUAL 0)
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_bhalf test_bhalf.cpp
)
test/data_type/test_bhalf.cpp
0 → 100644
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bhalf_t
;
using
ck
::
type_convert
;
TEST
(
BHALF_T
,
Nan
)
{
const
uint16_t
binary_bhalf_nan
=
0x7FC0
;
const
bhalf_t
bhalf_nan
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_nan
);
EXPECT_EQ
(
bhalf_nan
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
QuietNaN
()));
}
TEST
(
BHALF_T
,
Inf
)
{
const
uint16_t
binary_bhalf_inf
=
0x7F80
;
const
bhalf_t
bhalf_inf
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_inf
);
EXPECT_EQ
(
bhalf_inf
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
Infinity
()));
}
TEST
(
BHALF_T
,
MantisaOverflow
)
{
const
float
abs_tol
=
std
::
pow
(
2
,
-
7
);
const
uint32_t
val
=
0x81FFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_NEAR
(
float_val
,
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
abs_tol
);
}
TEST
(
BHALF_T
,
ExpOverflow
)
{
const
uint32_t
val
=
0xFF800000
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_EQ
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
float_val
);
}
TEST
(
BHALF_T
,
MantisaExpOverflow
)
{
const
uint32_t
val
=
0xFFFFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_TRUE
(
std
::
isnan
(
float_val
));
ASSERT_TRUE
(
std
::
isnan
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
))));
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
a4501f13
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
...
...
@@ -43,7 +43,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
return
true
;
}
}
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
())
{
// on gfx11 only support for 3d is implemented
...
...
@@ -143,19 +142,23 @@ using KernelTypes2d = ::testing::Types<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ck
::
Number
<
3
>>>
;
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
...
...
@@ -179,6 +182,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
{
2
,
2
,
64
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
3
,
3
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
5
,
5
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
a4501f13
...
...
@@ -64,6 +64,7 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
float
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
int8_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
...
...
Prev
1
…
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment