Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
832b69cb
"include/vscode:/vscode.git/clone" did not exist on "fd76c787211a131fcb6790aff48faead4459a3fa"
Unverified
Commit
832b69cb
authored
Jun 14, 2023
by
zjing14
Committed by
GitHub
Jun 14, 2023
Browse files
Merge branch 'develop' into grouped_conv_3d_layout_fix
parents
53130727
a35456a3
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
0 deletions
+74
-0
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
+74
-0
No files found.
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
0 → 100644
View file @
832b69cb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"
namespace
{
using
F16
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
template
<
typename
Tuple
>
class
TestBatchedGemmMultiD
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
static
constexpr
int
M
=
512
;
static
constexpr
int
N
=
256
;
static
constexpr
int
K
=
128
;
static
constexpr
int
BatchCount
=
3
;
template
<
typename
DataType
>
void
Run
()
{
using
namespace
ck
::
tensor_operation
::
device
;
const
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
ALayout
,
BLayout
,
Empty_Tuple
,
CLayout
,
DataType
,
DataType
,
Empty_Tuple
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
EXPECT_TRUE
(
pass
);
}
};
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
Row
,
Row
,
Row
>
,
std
::
tuple
<
Row
,
Col
,
Row
>
,
std
::
tuple
<
Col
,
Row
,
Row
>
,
std
::
tuple
<
Col
,
Col
,
Row
>>
;
}
// namespace
TYPED_TEST_SUITE
(
TestBatchedGemmMultiD
,
KernelTypes
);
TYPED_TEST
(
TestBatchedGemmMultiD
,
f16
)
{
this
->
template
Run
<
F16
>();
}
TYPED_TEST
(
TestBatchedGemmMultiD
,
int8
)
{
this
->
template
Run
<
int8_t
>();
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment