gaoqiong / composable_kernel_ROCM · Commits

Commit b924e330, authored Oct 03, 2024 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 72c9f129, 9c0811f3

Showing 13 changed files with 1111 additions and 46 deletions (+1111, -46)
profiler/src/profile_pool3d_fwd.cpp                    +331  -0
script/convert_miopen_driver_to_profiler.py            +2    -0
test/gemm_universal/test_gemm_universal_ut_cases.inc   +128  -0
test/gemm_universal/test_gemm_universal_xdl.cpp        +25   -0
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp    +3    -1
test/pool/CMakeLists.txt                               +12   -0
test/pool/test_avg_pool2d_bwd.cpp                      +133  -0
test/pool/test_avg_pool2d_fwd.cpp                      +145  -0
test/pool/test_avg_pool3d_fwd.cpp                      +18   -17
test/pool/test_max_pool2d_bwd.cpp                      +139  -0
test/pool/test_max_pool2d_fwd.cpp                      +150  -0
test/pool/test_max_pool3d_fwd.cpp                      +19   -27
test/pool/test_pool_fwd_common.hpp                     +6    -1
profiler/src/profile_pool3d_fwd.cpp (new file, mode 100644):
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>
#include <unordered_map>

#include "profiler/data_type_enum.hpp"
#include "profiler/profile_pool3d_fwd_impl.hpp"
#include "profiler_operation_registry.hpp"

using ck::index_t;

struct poolFwdArgParser
{
    std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}},
                                                                   {"wsize", {}},
                                                                   {"wstride", {}},
                                                                   {"wdilation", {}},
                                                                   {"pad1", {}},
                                                                   {"pad2", {}}};

    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
    {
        if(std::string("--") + key == argv[i])
        {
            int pos = i;
            while(++i < argc && argv[i][0] != '-') {}
            int end = i;
            for(int j = pos + 1; j < end; j++)
            {
                long_opts[key].push_back(std::stoi(argv[j]));
            }
            return true;
        }
        return false;
    }

    void operator()(int argc, char* argv[])
    {
        for(auto& kv : long_opts)
        {
            for(int i = 1; i < argc; i++)
            {
                if(parse_opt(argc, argv, kv.first, i))
                    break;
            }
        }
    }
};

void print_help_pool3d_fwd()
{
    std::cout << "arg1: data type (0: fp16; 1: fp32; 3: int8; 5: bf16; 7: fp8)\n"
              << "arg2: verification (0: no; 1: yes)\n"
              << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
              << "arg4: print tensor value (0: no; 1: yes)\n"
              << "arg5: time kernel (0=no, 1=yes)\n"
              << "arg6: return index (0=no, 1=yes)\n"
              << "arg7: reduce op (0: max; 1: avg)\n"
              << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30)\n"
              << "--wsize: window size for ZYX (e.g, --wsize 2 2 2)\n"
              << "--wstride: window stride for DHW (e.g, --wstride 2 2 2)\n"
              << "--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1)\n"
              << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1)\n"
              << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1)\n"
              << "eg: ckProfiler pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 "
                 "--wstride 2 2 2 --wdilation 1 1 1 --pad1 1 1 1 --pad2 1 1 1"
              << std::endl;
}

int profile_pool3d_fwd(int argc, char* argv[])
{
    ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
    ck::profiler::PoolFwdInputParams in_params{true, 0, false, true, false, 0};
    ck::profiler::PoolFwdKernelParams kernel_params{
        {2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};

    if(argc != 2 && argc != 35)
    {
        print_help_pool3d_fwd();
        return 0;
    }
    else if(argc == 35)
    {
        data_type                 = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
        in_params.do_verification = std::stoi(argv[3]);
        in_params.init_method     = std::stoi(argv[4]);
        in_params.do_log          = std::stoi(argv[5]);
        in_params.time_kernel     = std::stoi(argv[6]);
        in_params.return_index    = std::stoi(argv[7]);
        in_params.reduce_op       = std::stoi(argv[8]);

        // parse the long options
        poolFwdArgParser arg_parser;
        arg_parser(argc, argv);

        kernel_params.in_length              = arg_parser.long_opts["length"];
        kernel_params.window_spatial_lengths = arg_parser.long_opts["wsize"];
        kernel_params.window_strides         = arg_parser.long_opts["wstride"];
        kernel_params.window_dilations       = arg_parser.long_opts["wdilation"];
        kernel_params.input_left_pads        = arg_parser.long_opts["pad1"];
        kernel_params.input_right_pads       = arg_parser.long_opts["pad2"];
    }

    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using F32  = float;
    using I8   = int8_t;
    using I32  = int32_t;
    using F8   = ck::f8_t;

    using NDHWC = ck::tensor_layout::convolution::NDHWC;

    if(data_type == ck::DataTypeEnum::Half)
    {
        if(in_params.reduce_op == 1)
        {
            ck::profiler::profile_pool3d_fwd_impl<F16, F16, F32, I32, NDHWC, NDHWC,
                                                  ck::ReduceTensorOp::AVG, false, false>(
                in_params, kernel_params);
        }
        else
        {
            // reduce_op == 0
            if(in_params.return_index)
            {
                ck::profiler::profile_pool3d_fwd_impl<F16, F16, F16, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, true>(
                    in_params, kernel_params);
            }
            else
            {
                ck::profiler::profile_pool3d_fwd_impl<F16, F16, F16, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, false>(
                    in_params, kernel_params);
            }
        }
    }
    else if(data_type == ck::DataTypeEnum::BFloat16)
    {
        if(in_params.reduce_op == 1)
        {
            ck::profiler::profile_pool3d_fwd_impl<BF16, BF16, F32, I32, NDHWC, NDHWC,
                                                  ck::ReduceTensorOp::AVG, false, false>(
                in_params, kernel_params);
        }
        else
        {
            // reduce_op == 0
            if(in_params.return_index)
            {
                ck::profiler::profile_pool3d_fwd_impl<BF16, BF16, BF16, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, true>(
                    in_params, kernel_params);
            }
            else
            {
                ck::profiler::profile_pool3d_fwd_impl<BF16, BF16, BF16, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, false>(
                    in_params, kernel_params);
            }
        }
    }
    else if(data_type == ck::DataTypeEnum::Float)
    {
        if(in_params.reduce_op == 1)
        {
            ck::profiler::profile_pool3d_fwd_impl<F32, F32, F32, I32, NDHWC, NDHWC,
                                                  ck::ReduceTensorOp::AVG, false, false>(
                in_params, kernel_params);
        }
        else
        {
            // reduce_op == 0
            if(in_params.return_index)
            {
                ck::profiler::profile_pool3d_fwd_impl<F32, F32, F32, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, true>(
                    in_params, kernel_params);
            }
            else
            {
                ck::profiler::profile_pool3d_fwd_impl<F32, F32, F32, I32, NDHWC, NDHWC,
                                                      ck::ReduceTensorOp::MAX, false, false>(
                    in_params, kernel_params);
            }
        }
    }
    else if(data_type == ck::DataTypeEnum::Float8)
    {
        if(in_params.reduce_op == 1)
        {
            return ck::profiler::profile_pool3d_fwd_impl<F8, F8, F32, I32, NDHWC, NDHWC,
                                                         ck::ReduceTensorOp::AVG, false, false>(
                in_params, kernel_params);
        }
        else
        {
            // reduce_op == 0
            if(in_params.return_index)
            {
                return ck::profiler::profile_pool3d_fwd_impl<F8, F8, F8, I32, NDHWC, NDHWC,
                                                             ck::ReduceTensorOp::MAX, false, true>(
                    in_params, kernel_params);
            }
            else
            {
                return ck::profiler::profile_pool3d_fwd_impl<F8, F8, F8, I32, NDHWC, NDHWC,
                                                             ck::ReduceTensorOp::MAX, false, false>(
                    in_params, kernel_params);
            }
        }
    }
    else if(data_type == ck::DataTypeEnum::Int8)
    {
        if(in_params.reduce_op == 1)
        {
            return ck::profiler::profile_pool3d_fwd_impl<I8, I8, I32, I32, NDHWC, NDHWC,
                                                         ck::ReduceTensorOp::AVG, false, false>(
                in_params, kernel_params);
        }
        else
        {
            // reduce_op == 0
            if(in_params.return_index)
            {
                return ck::profiler::profile_pool3d_fwd_impl<I8, I8, I8, I32, NDHWC, NDHWC,
                                                             ck::ReduceTensorOp::MAX, false, true>(
                    in_params, kernel_params);
            }
            else
            {
                return ck::profiler::profile_pool3d_fwd_impl<I8, I8, I8, I32, NDHWC, NDHWC,
                                                             ck::ReduceTensorOp::MAX, false, false>(
                    in_params, kernel_params);
            }
        }
    }
    else
    {
        throw std::runtime_error("not implemented yet");
    }

    return 0;
}

REGISTER_PROFILER_OPERATION("pool3d_fwd", "pool3d fwd", profile_pool3d_fwd);
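For reference, the long-option parser above simply collects every integer that follows a `--<key>` flag until the next flag begins. A minimal standalone sketch of that same idea, with a made-up argument list purely for illustration:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Standalone illustration of the poolFwdArgParser idea: every integer after
// "--length" (or "--wsize") is appended until the next "--" flag starts.
int main()
{
    const char* argv[] = {"ckProfiler", "pool3d_fwd", "--length", "2", "32", "30", "30", "30",
                          "--wsize", "2", "2", "2"};
    int argc = sizeof(argv) / sizeof(argv[0]);

    std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}, {"wsize", {}}};

    for(auto& kv : long_opts)
    {
        for(int i = 1; i < argc; ++i)
        {
            if(std::string("--") + kv.first == argv[i])
            {
                int j = i + 1;
                while(j < argc && argv[j][0] != '-')
                    long_opts[kv.first].push_back(std::stoi(argv[j++]));
                break;
            }
        }
    }

    assert(long_opts["length"].size() == 5); // {2, 32, 30, 30, 30}
    assert(long_opts["wsize"].size() == 3);  // {2, 2, 2}
}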
script/convert_miopen_driver_to_profiler.py:
@@ -28,6 +28,8 @@ def parse_layouts(args):
         args.in_layout == "NCDHW":
         if args.ck_profier_op == "grouped_conv_bwd_weight":
             args.layout = 3
+        elif args.ck_profier_op == "grouped_conv_fwd":
+            args.layout = 2
         else:
             print('Not supported layout for this op')
             exit(1)
...
test/gemm_universal/test_gemm_universal_ut_cases.inc:
@@ -28,6 +28,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM)
         this->Run(M, N, K, StrideA, StrideB, StrideC);
 }
 
+TYPED_TEST(TestGemmUniversal_KM_KN, SmallM)
+{
+    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
+    constexpr int N = 512;
+    constexpr int K = 320;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
+TYPED_TEST(TestGemmUniversal_KM_NK, SmallM)
+{
+    std::vector<int> Ms{1, 2, 3, 4, 5, 6};
+    constexpr int N = 512;
+    constexpr int K = 320;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
 TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM)
 {
     std::vector<int> Ms{127, 255, 312, 799, 1573};
...
@@ -56,6 +88,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM)
         this->Run(M, N, K, StrideA, StrideB, StrideC);
 }
 
+TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM)
+{
+    std::vector<int> Ms{127, 255, 312, 799, 1573};
+    constexpr int N = 512;
+    constexpr int K = 320;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
+TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM)
+{
+    std::vector<int> Ms{127, 255, 312, 799, 1573};
+    constexpr int N = 512;
+    constexpr int K = 320;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
 TYPED_TEST(TestGemmUniversal_MK_KN, PaddK)
 {
     std::vector<int> Ms{127};
...
@@ -84,6 +148,38 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK)
         this->Run(M, N, K, StrideA, StrideB, StrideC);
 }
 
+TYPED_TEST(TestGemmUniversal_KM_KN, PaddK)
+{
+    std::vector<int> Ms{127};
+    constexpr int N = 512;
+    constexpr int K = 437;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
+TYPED_TEST(TestGemmUniversal_KM_NK, PaddK)
+{
+    std::vector<int> Ms{127};
+    constexpr int N = 512;
+    constexpr int K = 437;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
 TYPED_TEST(TestGemmUniversal_MK_KN, Regular)
 {
     std::vector<int> Ms{512};
...
@@ -111,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular)
     for(int M : Ms)
         this->Run(M, N, K, StrideA, StrideB, StrideC);
 }
+
+TYPED_TEST(TestGemmUniversal_KM_KN, Regular)
+{
+    std::vector<int> Ms{512};
+    constexpr int N = 512;
+    constexpr int K = 512;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
+
+TYPED_TEST(TestGemmUniversal_KM_NK, Regular)
+{
+    std::vector<int> Ms{512};
+    constexpr int N = 512;
+    constexpr int K = 512;
+
+    constexpr int StrideB = N;
+    constexpr int StrideC = N;
+
+    for(int M : Ms)
+    {
+        int StrideA = M;
+        this->Run(M, N, K, StrideA, StrideB, StrideC);
+    }
+}
test/gemm_universal/test_gemm_universal_xdl.cpp:
@@ -40,6 +40,18 @@ class TestGemmUniversal_MK_NK
 {
 };
 
+template <typename Tuple>
+class TestGemmUniversal_KM_KN
+    : public ck::test::TestGemmUniversal<
+          typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
+{
+};
+
+template <typename Tuple>
+class TestGemmUniversal_KM_NK
+    : public ck::test::TestGemmUniversal<
+          typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
+{
+};
+
 // clang-format off
 using KernelTypes_MK_KN = ::testing::Types<
     // ADataType, BDataType, ComputeDataType, CDataType
...
@@ -61,9 +73,22 @@ using KernelTypes_MK_NK = ::testing::Types<
 #endif
     std::tuple<BF16, BF16, BF16, BF16>
     >;
+using KernelTypes_KM_NK = ::testing::Types<
+    // ADataType, BDataType, ComputeDataType, CDataType
+    std::tuple<BF16, BF16, BF16, BF16>
+    >;
+using KernelTypes_KM_KN = ::testing::Types<
+    // ADataType, BDataType, ComputeDataType, CDataType
+    std::tuple<BF16, BF16, BF16, BF16>
+    >;
 // clang-format on
 
 TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN);
 TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK);
+TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN);
+TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK);
 
 #include "test_gemm_universal_ut_cases.inc"
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp:
@@ -62,7 +62,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
                                        std::tuple<float, NHWGC, GKYXC, NHWGK>,
                                        std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
                                        std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
-                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
+                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
+                                       std::tuple<float, NGCHW, GKYXC, NGKHW>,
+                                       std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>>;
 
 using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
                                        std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
...
test/pool/CMakeLists.txt:
@@ -4,13 +4,25 @@ add_gtest_executable(test_avg_pool3d_bwd test_avg_pool3d_bwd.cpp)
 add_gtest_executable(test_max_pool3d_bwd test_max_pool3d_bwd.cpp)
 add_gtest_executable(test_avg_pool3d_fwd test_avg_pool3d_fwd.cpp)
 add_gtest_executable(test_max_pool3d_fwd test_max_pool3d_fwd.cpp)
+add_gtest_executable(test_avg_pool2d_bwd test_avg_pool2d_bwd.cpp)
+add_gtest_executable(test_max_pool2d_bwd test_max_pool2d_bwd.cpp)
+add_gtest_executable(test_avg_pool2d_fwd test_avg_pool2d_fwd.cpp)
+add_gtest_executable(test_max_pool2d_fwd test_max_pool2d_fwd.cpp)
 
 target_link_libraries(test_avg_pool3d_bwd PRIVATE utility device_avg_pool3d_bwd_instance)
+target_link_libraries(test_avg_pool2d_bwd PRIVATE utility device_avg_pool2d_bwd_instance)
+target_link_libraries(test_max_pool2d_bwd PRIVATE utility device_max_pool_bwd_instance)
 target_link_libraries(test_max_pool3d_bwd PRIVATE utility device_max_pool_bwd_instance)
 target_link_libraries(test_avg_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance)
 target_link_libraries(test_max_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance)
+target_link_libraries(test_avg_pool2d_fwd PRIVATE utility device_pool2d_fwd_instance)
+target_link_libraries(test_max_pool2d_fwd PRIVATE utility device_pool2d_fwd_instance)
 
 add_dependencies(test_pool test_avg_pool3d_bwd)
 add_dependencies(test_pool test_max_pool3d_bwd)
 add_dependencies(test_pool test_avg_pool3d_fwd)
 add_dependencies(test_pool test_max_pool3d_fwd)
+add_dependencies(test_pool test_avg_pool2d_bwd)
+add_dependencies(test_pool test_max_pool2d_bwd)
+add_dependencies(test_pool test_avg_pool2d_fwd)
+add_dependencies(test_pool test_max_pool2d_fwd)
test/pool/test_avg_pool2d_bwd.cpp (new file, mode 100644):
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_avg_pool2d_bwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

template <typename T>
class AvgPool2dBWDTest : public ::testing::Test
{
    protected:
    using InDataType  = std::tuple_element_t<0, T>;
    using OutDataType = std::tuple_element_t<1, T>;

    static std::vector<PoolingParam> params;

    void Run()
    {
        for(auto param : this->params)
        {
            bool success =
                ck::profiler::profile_avg_pool2d_bwd_impl<InDataType, OutDataType, NHWC, NHWC>(
                    true,
                    2,
                    false,
                    false,
                    param.length_,
                    param.window_spatial_lengths_,
                    param.window_strides_,
                    param.window_dilations_,
                    param.input_left_pads_,
                    param.input_right_pads_);
            EXPECT_TRUE(success);
        }
    }
};

template <typename T>
std::vector<PoolingParam> AvgPool2dBWDTest<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{1, 1, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{1, 5, 7, 7}, {2, 2}, {2, 2}, {1, 1}, {2, 2}, {0, 0}},
    {{1, 1, 8, 8}, {2, 2}, {2, 2}, {1, 1}, {2, 2}, {0, 0}},
    {{1, 1, 8, 8}, {2, 2}, {1, 1}, {1, 1}, {1, 1}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}},
    {{1, 2, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}};

using Avg_Pool_2D_f32_types  = ::testing::Types<std::tuple<F32, F32>>;
using Avg_Pool_2D_int8_types = ::testing::Types<std::tuple<I8, I8>>;
using Avg_Pool_2D_f16_types  = ::testing::Types<std::tuple<F16, F16>>;
using Avg_Pool_2D_bf16_types = ::testing::Types<std::tuple<BF16, BF16>>;
using Avg_Pool_2D_f8_types   = ::testing::Types<std::tuple<F8, F8>>;

template <typename TType>
class AvgPool2D_f32 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP32)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_int8 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_INT8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_int8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_f16 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f16 because CK_ENABLE_FP16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_bf16 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_BF16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_bf16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_f8 : public AvgPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_f8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

TYPED_TEST_SUITE(AvgPool2D_f32, Avg_Pool_2D_f32_types);
TYPED_TEST_SUITE(AvgPool2D_int8, Avg_Pool_2D_int8_types);
TYPED_TEST_SUITE(AvgPool2D_f16, Avg_Pool_2D_f16_types);
TYPED_TEST_SUITE(AvgPool2D_bf16, Avg_Pool_2D_bf16_types);
TYPED_TEST_SUITE(AvgPool2D_f8, Avg_Pool_2D_f8_types);

TYPED_TEST(AvgPool2D_f32, AvgPool2DTest_f32) { this->Run(); }
TYPED_TEST(AvgPool2D_int8, AvgPool2DTest_int8) { this->Run(); }
TYPED_TEST(AvgPool2D_f16, AvgPool2DTest_f16) { this->Run(); }
TYPED_TEST(AvgPool2D_bf16, AvgPool2DTest_bf16) { this->Run(); }
TYPED_TEST(AvgPool2D_f8, AvgPool2DTest_f8) { this->Run(); }
test/pool/test_avg_pool2d_fwd.cpp (new file, mode 100644):
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_pool2d_fwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

template <typename Tuple>
class TestAvgPool2dFwd : public ::testing::Test
{
    protected:
    using InDataType      = std::tuple_element_t<0, Tuple>;
    using OutDataType     = std::tuple_element_t<1, Tuple>;
    using ComputeDataType = std::tuple_element_t<2, Tuple>;
    using IndexDataType   = std::tuple_element_t<3, Tuple>;

    static std::vector<PoolingParam> params;

    void Run()
    {
        for(auto param : params)
        {
            bool success =
                ck::profiler::profile_pool2d_fwd_impl<InDataType,
                                                      OutDataType,
                                                      ComputeDataType,
                                                      IndexDataType,
                                                      ck::tensor_layout::convolution::NHWC,
                                                      ck::tensor_layout::convolution::NHWC,
                                                      ck::ReduceTensorOp::AVG,
                                                      false,
                                                      false>(true,
                                                             2,
                                                             false,
                                                             false,
                                                             param.length_,
                                                             param.window_spatial_lengths_,
                                                             param.window_strides_,
                                                             param.window_dilations_,
                                                             param.input_left_pads_,
                                                             param.input_right_pads_);
            EXPECT_TRUE(success);
        }
    }
};

template <typename T>
std::vector<PoolingParam> TestAvgPool2dFwd<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};

using AvgPool2D_F32_Types =
    ::testing::Types<std::tuple<F32, F32, F32, I32>, std::tuple<F32, F32, F32, I32>>;
using AvgPool2D_F16_Types =
    ::testing::Types<std::tuple<F16, F16, F32, I32>, std::tuple<F16, F16, F32, I32>>;
using AvgPool2D_BF16_Types =
    ::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<BF16, BF16, F32, I32>>;
using AvgPool2D_I8_Types =
    ::testing::Types<std::tuple<I8, I8, F32, I32>, std::tuple<I8, I8, F32, I32>>;
using AvgPool2D_F8_Types =
    ::testing::Types<std::tuple<F8, F8, F32, I32>, std::tuple<F8, F8, F32, I32>>;

template <typename TType>
class AvgPool2D_F32 : public TestAvgPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP32)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_F32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_F16 : public TestAvgPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_F16 tests because CK_ENABLE_FP16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_BF16 : public TestAvgPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_BF16)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_BF16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_I8 : public TestAvgPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_INT8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_I8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

template <typename TType>
class AvgPool2D_F8 : public TestAvgPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP8)
        {
            GTEST_SKIP() << "Skipping AvgPool2D_F8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

TYPED_TEST_SUITE(AvgPool2D_F32, AvgPool2D_F32_Types);
TYPED_TEST_SUITE(AvgPool2D_F16, AvgPool2D_F16_Types);
TYPED_TEST_SUITE(AvgPool2D_BF16, AvgPool2D_BF16_Types);
TYPED_TEST_SUITE(AvgPool2D_I8, AvgPool2D_I8_Types);
TYPED_TEST_SUITE(AvgPool2D_F8, AvgPool2D_F8_Types);

TYPED_TEST(AvgPool2D_F32, AvgPool2D_I8_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F16, AvgPool2D_F16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_BF16, AvgPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_I8, AvgPool2D_I8_Test) { this->Run(); }
TYPED_TEST(AvgPool2D_F8, AvgPool2D_F8_Test) { this->Run(); }
test/pool/test_avg_pool3d_fwd.cpp:
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "gtest/gtest.h"
 #include "profiler/profile_pool3d_fwd_impl.hpp"
...
@@ -16,10 +16,19 @@ class TestAvgPool3dFwd : public ::testing::Test
     std::vector<PoolingParam> params;
+    ck::profiler::PoolFwdInputParams in_params_avg_pool{true, 2, false, false, false, 1};
 
     void Run()
     {
         for(auto param : params)
         {
+            ck::profiler::PoolFwdKernelParams kernel_params{param.length_,
+                                                            param.window_spatial_lengths_,
+                                                            param.window_strides_,
+                                                            param.window_dilations_,
+                                                            param.input_left_pads_,
+                                                            param.input_right_pads_};
             bool success =
                 ck::profiler::profile_pool3d_fwd_impl<InDataType,
                                                       OutDataType,
...
@@ -29,26 +38,18 @@ class TestAvgPool3dFwd : public ::testing::Test
                                                       ck::tensor_layout::convolution::NDHWC,
                                                       ck::ReduceTensorOp::AVG,
                                                       false,
-                                                      false>(true,
-                                                             2,
-                                                             false,
-                                                             false,
-                                                             param.length_,
-                                                             param.window_spatial_lengths_,
-                                                             param.window_strides_,
-                                                             param.window_dilations_,
-                                                             param.input_left_pads_,
-                                                             param.input_right_pads_);
+                                                      false>(in_params_avg_pool, kernel_params);
             EXPECT_TRUE(success);
         }
     }
 };
 
-using KernelTypes = ::testing::Types<std::tuple<I8, I8, I32, I32>,
-                                     std::tuple<F16, F16, F32, I32>,
-                                     std::tuple<F32, F32, F32, I32>>;
+#ifdef CK_ENABLE_FP16
+using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, I32>,
+                                     std::tuple<F8, F8, F32, I32>,
+                                     std::tuple<BF16, BF16, F32, I32>,
+                                     std::tuple<F32, F32, F32, I32>>;
+#else
+using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
+#endif
 
 TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes);
 
 TYPED_TEST(TestAvgPool3dFwd, Test_Pool)
 {
...
test/pool/test_max_pool2d_bwd.cpp (new file, mode 100644):
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_max_pool2d_bwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

template <typename T>
class MaxPool2dBWDTest : public ::testing::Test
{
    protected:
    using DOutDataType  = std::tuple_element_t<0, T>;
    using DInDataType   = std::tuple_element_t<1, T>;
    using IndexDataType = std::tuple_element_t<2, T>;

    using InDataType  = DInDataType;
    using OutDataType = DOutDataType;

    static std::vector<PoolingParam> params;

    void Run()
    {
        for(auto param : this->params)
        {
            bool success = ck::profiler::profile_max_pool2d_bwd_impl<InDataType,
                                                                     OutDataType,
                                                                     IndexDataType,
                                                                     DOutDataType,
                                                                     DInDataType,
                                                                     false>(
                true,
                2,
                false,
                false,
                param.length_,
                param.window_spatial_lengths_,
                param.window_strides_,
                param.window_dilations_,
                param.input_left_pads_,
                param.input_right_pads_);
            EXPECT_TRUE(success);
        }
    }
};

template <typename T>
std::vector<PoolingParam> MaxPool2dBWDTest<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}},
    {{2, 2, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};

using Max_Pool_2D_f32_types  = ::testing::Types<std::tuple<F32, F32, I32>>;
using Max_Pool_2D_int8_types = ::testing::Types<std::tuple<I8, I8, I32>>;
using Max_Pool_2D_f16_types  = ::testing::Types<std::tuple<F16, F16, I32>>;
using Max_Pool_2D_bf16_types = ::testing::Types<std::tuple<BF16, BF16, I32>>;
using Max_Pool_2D_f8_types   = ::testing::Types<std::tuple<F8, F8, I32>>;

template <typename TType>
class MaxPool2D_f32 : public MaxPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP32)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_f32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_int8 : public MaxPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_INT8)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_int8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_f16 : public MaxPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_f16 because CK_ENABLE_FP16 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_bf16 : public MaxPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_BF16)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_bf16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_f8 : public MaxPool2dBWDTest<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP8)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_f8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

TYPED_TEST_SUITE(MaxPool2D_f32, Max_Pool_2D_f32_types);
TYPED_TEST_SUITE(MaxPool2D_int8, Max_Pool_2D_int8_types);
TYPED_TEST_SUITE(MaxPool2D_f16, Max_Pool_2D_f16_types);
TYPED_TEST_SUITE(MaxPool2D_bf16, Max_Pool_2D_bf16_types);
TYPED_TEST_SUITE(MaxPool2D_f8, Max_Pool_2D_f8_types);

TYPED_TEST(MaxPool2D_f32, MaxPool2DTest_f32) { this->Run(); }
TYPED_TEST(MaxPool2D_int8, MaxPool2DTest_int8) { this->Run(); }
TYPED_TEST(MaxPool2D_f16, MaxPool2DTest_f16) { this->Run(); }
TYPED_TEST(MaxPool2D_bf16, MaxPool2DTest_bf16) { this->Run(); }
TYPED_TEST(MaxPool2D_f8, MaxPool2DTest_f8) { this->Run(); }
test/pool/test_max_pool2d_fwd.cpp (new file, mode 100644):
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "profiler/profile_pool2d_fwd_impl.hpp"
#include "test_pool_fwd_common.hpp"

template <typename Tuple>
class TestMaxPool2dFwd : public ::testing::Test
{
    protected:
    using InDataType      = std::tuple_element_t<0, Tuple>;
    using OutDataType     = std::tuple_element_t<1, Tuple>;
    using ComputeDataType = std::tuple_element_t<2, Tuple>;
    using IndexDataType   = std::tuple_element_t<3, Tuple>;

    static constexpr bool ReturnIndex = std::tuple_element_t<4, Tuple>::value;

    static std::vector<PoolingParam> params;

    void Run()
    {
        for(auto param : params)
        {
            // max pool
            bool success =
                ck::profiler::profile_pool2d_fwd_impl<InDataType,
                                                      OutDataType,
                                                      ComputeDataType,
                                                      IndexDataType,
                                                      ck::tensor_layout::convolution::NHWC,
                                                      ck::tensor_layout::convolution::NHWC,
                                                      ck::ReduceTensorOp::MAX,
                                                      false,
                                                      ReturnIndex>(true,
                                                                   2,
                                                                   false,
                                                                   false,
                                                                   param.length_,
                                                                   param.window_spatial_lengths_,
                                                                   param.window_strides_,
                                                                   param.window_dilations_,
                                                                   param.input_left_pads_,
                                                                   param.input_right_pads_);
            EXPECT_TRUE(success);
        }
    }
};

template <typename T>
std::vector<PoolingParam> TestMaxPool2dFwd<T>::params = {
    {{1, 1, 1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
    {{2, 16, 64, 64}, {4, 4}, {4, 4}, {2, 2}, {0, 0}, {0, 0}},
    {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}};

using true_t  = std::integral_constant<bool, true>;
using false_t = std::integral_constant<bool, false>;

using MaxPool2D_F32_Types = ::testing::Types<std::tuple<F32, F32, F32, I32, true_t>,
                                             std::tuple<F32, F32, F32, I32, false_t>>;
using MaxPool2D_F16_Types = ::testing::Types<std::tuple<F16, F16, F32, I32, true_t>,
                                             std::tuple<F16, F16, F32, I32, false_t>>;
using MaxPool2D_BF16_Types = ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>,
                                              std::tuple<BF16, BF16, F32, I32, false_t>>;
using MaxPool2D_I8_Types = ::testing::Types<std::tuple<I8, I8, F32, I32, true_t>,
                                            std::tuple<I8, I8, F32, I32, false_t>>;
using MaxPool2D_F8_Types = ::testing::Types<std::tuple<F8, F8, F32, I32, true_t>,
                                            std::tuple<F8, F8, F32, I32, false_t>>;

template <typename TType>
class MaxPool2D_F32 : public TestMaxPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP32)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_F32 tests because CK_ENABLE_FP32 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_F16 : public TestMaxPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP16)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_F16 tests because CK_ENABLE_FP16 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_BF16 : public TestMaxPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_BF16)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_BF16 tests because CK_ENABLE_BF16 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_I8 : public TestMaxPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_INT8)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_I8 tests because CK_ENABLE_INT8 is not enabled";
        }
    }
};

template <typename TType>
class MaxPool2D_F8 : public TestMaxPool2dFwd<TType>
{
    protected:
    void SetUp() override
    {
        if(!CK_ENABLE_FP8)
        {
            GTEST_SKIP() << "Skipping MaxPool2D_F8 tests because CK_ENABLE_FP8 is not enabled";
        }
    }
};

TYPED_TEST_SUITE(MaxPool2D_F32, MaxPool2D_F32_Types);
TYPED_TEST_SUITE(MaxPool2D_F16, MaxPool2D_F16_Types);
TYPED_TEST_SUITE(MaxPool2D_BF16, MaxPool2D_BF16_Types);
TYPED_TEST_SUITE(MaxPool2D_I8, MaxPool2D_I8_Types);
TYPED_TEST_SUITE(MaxPool2D_F8, MaxPool2D_F8_Types);

TYPED_TEST(MaxPool2D_F32, MaxPool2D_I8_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F16, MaxPool2D_F16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_BF16, MaxPool2D_BF16_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_I8, MaxPool2D_I8_Test) { this->Run(); }
TYPED_TEST(MaxPool2D_F8, MaxPool2D_F8_Test) { this->Run(); }
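The true_t/false_t aliases in the file above exist because ::testing::Types only accepts types, so the ReturnIndex flag has to travel through the test tuple as a std::integral_constant and be unpacked with std::tuple_element_t. A minimal self-contained sketch of that mechanism, with plain stand-in element types rather than the real CK ones:

#include <tuple>
#include <type_traits>

using true_t  = std::integral_constant<bool, true>;
using false_t = std::integral_constant<bool, false>;

// The fifth tuple element carries the ReturnIndex flag as a type, so a typed
// test fixture can read it back as a compile-time constant.
using SampleTuple = std::tuple<float, float, float, int, true_t>;
static_assert(std::tuple_element_t<4, SampleTuple>::value,
              "ReturnIndex is recovered at compile time");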
test/pool/test_max_pool3d_fwd.cpp:
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "gtest/gtest.h"
 #include "profiler/profile_pool3d_fwd_impl.hpp"
...
@@ -16,10 +16,20 @@ class TestMaxPool3dFwd : public ::testing::Test
     std::vector<PoolingParam> params;
+    ck::profiler::PoolFwdInputParams in_params_max_pool{true, 2, false, false, false, 0};
+    ck::profiler::PoolFwdInputParams in_params_max_pool_indexed{true, 2, false, false, true, 0};
 
     void Run()
     {
         for(auto param : params)
         {
+            ck::profiler::PoolFwdKernelParams kernel_params{param.length_,
+                                                            param.window_spatial_lengths_,
+                                                            param.window_strides_,
+                                                            param.window_dilations_,
+                                                            param.input_left_pads_,
+                                                            param.input_right_pads_};
             // max pool
             bool success =
                 ck::profiler::profile_pool3d_fwd_impl<InDataType,
...
@@ -30,16 +40,7 @@ class TestMaxPool3dFwd : public ::testing::Test
                                                       ck::tensor_layout::convolution::NDHWC,
                                                       ck::ReduceTensorOp::MAX,
                                                       false,
-                                                      false>(true,
-                                                             2,
-                                                             false,
-                                                             false,
-                                                             param.length_,
-                                                             param.window_spatial_lengths_,
-                                                             param.window_strides_,
-                                                             param.window_dilations_,
-                                                             param.input_left_pads_,
-                                                             param.input_right_pads_);
+                                                      false>(in_params_max_pool, kernel_params);
             EXPECT_TRUE(success);
 
             // max pool + index
...
@@ -51,27 +52,18 @@ class TestMaxPool3dFwd : public ::testing::Test
                                                       ck::tensor_layout::convolution::NDHWC,
                                                       ck::ReduceTensorOp::MAX,
                                                       false,
-                                                      true>(true,
-                                                            2,
-                                                            false,
-                                                            false,
-                                                            param.length_,
-                                                            param.window_spatial_lengths_,
-                                                            param.window_strides_,
-                                                            param.window_dilations_,
-                                                            param.input_left_pads_,
-                                                            param.input_right_pads_);
+                                                      true>(in_params_max_pool_indexed,
+                                                            kernel_params);
             EXPECT_TRUE(success);
         }
     }
 };
 
-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, I32>,
-                                     std::tuple<F32, F32, F32, I32>>;
+#ifdef CK_ENABLE_FP16
+using KernelTypes = ::testing::Types<std::tuple<I8, I8, I8, I32>,
+                                     std::tuple<F8, F8, F8, I32>,
+                                     std::tuple<F16, F16, F16, I32>,
+                                     std::tuple<BF16, BF16, BF16, I32>,
+                                     std::tuple<F32, F32, F32, I32>>;
+#else
+using KernelTypes = ::testing::Types<std::tuple<F32, F32, F32, I32>>;
+#endif
 
 TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes);
 
 TYPED_TEST(TestMaxPool3dFwd, Test_Pool)
...
test/pool/test_pool_fwd_common.hpp:
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "gtest/gtest.h"
 #include "ck/ck.hpp"
 
+using I8   = int8_t;
+using F8   = ck::f8_t;
 using F16  = ck::half_t;
 using BF16 = ck::bhalf_t;
 using F32  = float;
 using I32  = int32_t;
-using I8   = int8_t;
-using F8   = ck::f8_t;
 
 using ck::index_t;
 
 using NDHWC = ck::tensor_layout::convolution::NDHWC;
+using NHWC  = ck::tensor_layout::convolution::NHWC;
 
 struct PoolingParam
 {
...
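The hunk above cuts off before the body of PoolingParam. Judging from the brace-initializers in the new 2D tests ({lengths, window sizes, strides, dilations, left pads, right pads}) and the members they read (length_, window_spatial_lengths_, window_strides_, window_dilations_, input_left_pads_, input_right_pads_), a plausible shape is sketched below; this is an inference for illustration, not the actual header, and it reuses the `using ck::index_t;` declaration from the file above:

#include <vector>

// Hypothetical reconstruction of PoolingParam, inferred from how the pool
// tests initialize it and which members they access; the real struct in
// test_pool_fwd_common.hpp may differ in types or add constructors.
struct PoolingParam
{
    std::vector<index_t> length_;                 // e.g. {N, C, H, W} or {N, C, D, H, W}
    std::vector<index_t> window_spatial_lengths_; // pooling window size
    std::vector<index_t> window_strides_;         // window strides
    std::vector<index_t> window_dilations_;       // window dilations
    std::vector<index_t> input_left_pads_;        // left/top padding
    std::vector<index_t> input_right_pads_;       // right/bottom padding
};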