gaoqiong / composable_kernel

Commit a01ca8e6, authored May 12, 2023 by Adam Osewski
Add gtests for gemm splitk using ckProfiler API.
parent e9fd26c8

Showing 5 changed files with 157 additions and 264 deletions:
test/gemm_split_k/CMakeLists.txt                 +2    -3
test/gemm_split_k/gemm_split_k.cpp               +0    -261
test/gemm_split_k/test_gemm_splitk.cpp           +62   -0
test/gemm_split_k/test_gemm_splitk_ut_cases.inc  +15   -0
test/gemm_split_k/test_gemm_splitk_util.hpp      +78   -0
test/gemm_split_k/CMakeLists.txt
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp)

target_link_libraries(test_gemm_split_k PRIVATE utility)
target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance)
test/gemm_split_k/gemm_split_k.cpp (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/host_gemm.hpp"

enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
    KM_KN_MN, // 2
    KM_NK_MN, // 3
};

template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    float max_diff = 1e-6;

    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            return false;
        }
    }

    return true;
}

struct gemmArgs
{
    GemmMatrixLayout layout;
    int M;
    int N;
    int K;
    int StrideA;
    int StrideB;
    int StrideC;
    int KBatch;
};

int test_gemm(const gemmArgs& args)
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    bool a_row_major, b_row_major, c_row_major;

    switch(args.layout)
    {
    case GemmMatrixLayout::MK_KN_MN:
        a_row_major = true;
        b_row_major = true;
        c_row_major = true;
        break;
    case GemmMatrixLayout::MK_NK_MN:
        a_row_major = true;
        b_row_major = false;
        c_row_major = true;
        break;
    case GemmMatrixLayout::KM_KN_MN:
        a_row_major = false;
        b_row_major = true;
        c_row_major = true;
        break;
    case GemmMatrixLayout::KM_NK_MN:
        a_row_major = false;
        b_row_major = false;
        c_row_major = true;
        break;
    default: printf("not supported layout"); return 1;
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
            using namespace ck::literals;

            if(row_major)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
    Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
    Tensor<float> c_m_n_host_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
    Tensor<float> c_m_n_device_result(
        f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));

    // init data
    std::size_t num_thread = 1;
    a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);

    host_gemm_mk_kn_mn(a_m_k,
                       b_k_n,
                       c_m_n_host_result,
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{});

    DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
        bool success = false;

        using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
                                                                        decltype(b_layout),
                                                                        decltype(c_layout),
                                                                        float,
                                                                        float,
                                                                        float,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

        const auto gemm_ptrs =
            ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
                DeviceOp>::GetInstances();

        for(auto& gemm_ptr : gemm_ptrs)
        {
            auto argument_ptr = gemm_ptr->MakeArgumentPointer(
                static_cast<float*>(a_device_buf.GetDeviceBuffer()),
                static_cast<float*>(b_device_buf.GetDeviceBuffer()),
                static_cast<float*>(c_device_buf.GetDeviceBuffer()),
                args.M,
                args.N,
                args.K,
                args.StrideA,
                args.StrideB,
                args.StrideC,
                ck::tensor_operation::element_wise::PassThrough{},
                ck::tensor_operation::element_wise::PassThrough{},
                ck::tensor_operation::element_wise::PassThrough{},
                args.KBatch);

            auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

            if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
            {
                invoker_ptr->Run(argument_ptr.get());

                c_device_buf.FromDevice(c_m_n_device_result.mData.data());

                if(!check_out(c_m_n_host_result, c_m_n_device_result))
                {
                    success = false;
                    break;
                }
                success = true;
            }
        }

        return success;
    };

    bool success = false;

    if(args.layout == GemmMatrixLayout::MK_KN_MN)
    {
        success = test(Row{}, Row{}, Row{});
    }
    else if(args.layout == GemmMatrixLayout::MK_NK_MN)
    {
        success = test(Row{}, Col{}, Row{});
    }
    else if(args.layout == GemmMatrixLayout::KM_KN_MN)
    {
        success = test(Col{}, Row{}, Row{});
    }
    else
    {
        success = test(Col{}, Col{}, Row{});
    }

    auto error_code = 0;
    if(success)
    {
        std::cout << "test split k : Pass" << std::endl;
    }
    else
    {
        std::cout << "test split k: Fail " << std::endl;
        error_code = -1; // test needs to report failure
    }
    return error_code;
}

int main(int argc, char* argv[])
{
    std::vector<gemmArgs> test_cases;

    if(argc == 1)
    {
        test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2},
                      {GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}};
    }
    else if(argc == 9)
    {
        const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));

        const int M       = std::stoi(argv[2]);
        const int N       = std::stoi(argv[3]);
        const int K       = std::stoi(argv[4]);
        const int StrideA = std::stoi(argv[5]);
        const int StrideB = std::stoi(argv[6]);
        const int StrideC = std::stoi(argv[7]);
        const int KBatch  = std::stoi(argv[8]);

        test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
    }
    else
    {
        printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
        return -1;
    }

    bool error = false;
    for(const auto& kinder : test_cases)
    {
        error |= test_gemm(kinder);
    }
    return error ? 1 : 0;
}
test/gemm_split_k/test_gemm_splitk.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

// #include <algorithm>
// #include <stdexcept>
#include <vector>

#include "gtest/gtest.h"

#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_splitk_util.hpp"

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

namespace {

template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
    using type = std::tuple<Xs..., Ys...>;
};

} // namespace

template <typename Tuple>
class TestGemmSplitK_MK_KN
    : public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};

// template <typename Tuple>
// class TestGemmSplitK_MK_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Row, Col>,
// Tuple>::type> {};

// template <typename Tuple>
// class TestGemmSplitK_KM_KN : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Row>,
// Tuple>::type> {};

// template <typename Tuple>
// class TestGemmSplitK_KM_NK : public ck::test::TestGemmSplitK<tuple_concat<std::tuple<Col, Col>,
// Tuple>::type> {};

// clang-format off
using KernelTypes = ::testing::Types<
    //         ADataType, BDataType, CDataType
    std::tuple<      F16,       F16,       F16>,
    std::tuple<      F32,       F32,       F32>
    >;
// clang-format on

TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
// TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);

#include "test_gemm_splitk_ut_cases.inc"
test/gemm_split_k/test_gemm_splitk_ut_cases.inc (new file, 0 → 100644)
#pragma once

TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
{
    std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
    int N = 512;
    int K = 320;

    int StrideA = K;
    int StrideB = N;
    int StrideC = N;

    for(int M : Ms)
        this->Run(M, N, K, StrideA, StrideB, StrideC);
}
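Additional cases follow the same pattern: the fixture's Run(M, N, K, StrideA, StrideB, StrideC) entry point sweeps every KBatch value configured in SetUp(), so a new TYPED_TEST only has to pick problem sizes. As a minimal sketch (illustrative only, not part of this commit), a hypothetical case sweeping small K values for the row-major/row-major suite could look like:

// Hypothetical additional case (not in this commit): sweep small K values through
// the same fixture. Run() iterates over the KBatch values set up by the fixture.
TYPED_TEST(TestGemmSplitK_MK_KN, SmallK)
{
    std::vector<int> Ks{32, 64, 96, 128};
    int M = 256;
    int N = 512;

    for(int K : Ks)
        this->Run(M, N, K, /*StrideA=*/K, /*StrideB=*/N, /*StrideC=*/N);
}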
test/gemm_split_k/test_gemm_splitk_util.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <sstream>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/data_type.hpp"

#include "profiler/profile_gemm_splitk_impl.hpp"

namespace ck {
namespace test {

template <typename Tuple>
class TestGemmSplitK : public testing::Test
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using F32 = float;

    protected:
    using ALayout = std::tuple_element_t<0, Tuple>;
    using BLayout = std::tuple_element_t<1, Tuple>;
    using CLayout = Row;

    using ADataType = std::tuple_element_t<2, Tuple>;
    using BDataType = std::tuple_element_t<3, Tuple>;
    using CDataType = std::tuple_element_t<4, Tuple>;

    public:
    bool verify_     = true;
    int init_method_ = 1; // decimal value initialization
    bool log_        = false;
    bool bench_      = false; // measure kernel performance
    std::vector<int> k_batches_;

    void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }

    void Run(const int M,
             const int N,
             const int K,
             const int StrideA,
             const int StrideB,
             const int StrideC)
    {
        for(auto kb : k_batches_)
        {
            RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
        }
    }

    void RunSingle(const int M,
                   const int N,
                   const int K,
                   const int StrideA,
                   const int StrideB,
                   const int StrideC,
                   int kbatch = 1)
    {
        bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
                                                           BDataType,
                                                           F32,
                                                           CDataType,
                                                           ALayout,
                                                           BLayout,
                                                           CLayout>(
            verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);

        EXPECT_TRUE(pass);
    }
};

} // namespace test
} // namespace ck
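Because verify_, init_method_, log_, bench_, and k_batches_ are public members of the fixture, individual test bodies can override them before calling Run. A minimal sketch, illustrative only (the SmallM case added in this commit uses the defaults); the test name BenchSingleKBatch and the chosen sizes are assumptions:

// Illustrative only: override fixture defaults from a test body before running.
// Replacing k_batches_ limits the sweep to KBatch = 4; bench_ also measures performance.
TYPED_TEST(TestGemmSplitK_MK_KN, BenchSingleKBatch)
{
    this->bench_     = true;
    this->k_batches_ = {4}; // instead of the default {1, 2, 3, 5, 8}
    this->Run(512, 512, 512, /*StrideA=*/512, /*StrideB=*/512, /*StrideC=*/512);
}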