Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8aa44bcd
Commit
8aa44bcd
authored
Aug 11, 2022
by
Anthony Chang
Browse files
add gemm_gemm instances and tests
parent
51fc99a8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
595 additions
and
0 deletions
+595
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
...brary/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
+93
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
...r_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
+8
-0
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+67
-0
profiler/include/profile_batched_gemm_gemm_impl.hpp
profiler/include/profile_batched_gemm_gemm_impl.hpp
+313
-0
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/batched_gemm_gemm/CMakeLists.txt
test/batched_gemm_gemm/CMakeLists.txt
+5
-0
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+39
-0
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
+68
-0
No files found.
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp
0 → 100644
View file @
8aa44bcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
typename
ALayout
,
typename
B0Layout
,
typename
B1Layout
,
typename
CLayout
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmGemm
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>
{
using
DeviceOp
=
DeviceBatchedGemmGemm
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
8aa44bcd
...
@@ -13,6 +13,7 @@ add_subdirectory(gemm_reduce)
...
@@ -13,6 +13,7 @@ add_subdirectory(gemm_reduce)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
batched_gemm_gemm
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
contraction_scale
)
add_subdirectory
(
contraction_scale
)
add_subdirectory
(
contraction_bilinear
)
add_subdirectory
(
contraction_bilinear
)
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt
0 → 100644
View file @
8aa44bcd
set
(
DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
add_instance_library
(
device_batched_gemm_gemm_instance OBJECT
${
DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE
}
)
target_compile_features
(
device_batched_gemm_gemm_instance PUBLIC
)
set_target_properties
(
device_batched_gemm_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
clang_tidy_check
(
device_batched_gemm_gemm_instance
)
\ No newline at end of file
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
0 → 100644
View file @
8aa44bcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profile_batched_gemm_gemm_impl.hpp
0 → 100644
View file @
8aa44bcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace
ck
{
namespace
profiler
{
template
<
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
ALayout
,
typename
B0Layout
,
typename
B1Layout
,
typename
CLayout
>
bool
profile_batched_gemm_gemm_impl
(
bool
do_verification
,
int
init_method
,
bool
do_log
,
bool
time_kernel
,
int
M
,
int
N
,
int
K
,
int
O
,
int
BatchCount
=
1
,
int
StrideA
=
-
1
,
int
StrideB0
=
-
1
,
int
StrideB1
=
-
1
,
int
StrideC
=
-
1
,
int
BatchStrideA
=
-
1
,
int
BatchStrideB0
=
-
1
,
int
BatchStrideB1
=
-
1
,
int
BatchStrideC
=
-
1
)
{
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
AccDataType
=
float
;
// Ref Gemm0
using
ReferenceGemm0Instance
=
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B0DataType
,
ADataType
,
AccDataType
,
AElementOp
,
B0ElementOp
,
CElementOp
>
;
// Ref Gemm
using
ReferenceGemm1Instance
=
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
B1DataType
,
CDataType
,
AccDataType
,
AElementOp
,
B1ElementOp
,
CElementOp
>
;
bool
pass
=
true
;
const
int
DefaultStrideA
=
ck
::
is_same_v
<
ALayout
,
Row
>
?
K
:
M
;
const
int
DefaultStrideB0
=
ck
::
is_same_v
<
B0Layout
,
Row
>
?
N
:
K
;
const
int
DefaultStrideB1
=
ck
::
is_same_v
<
B1Layout
,
Row
>
?
O
:
N
;
const
int
DefaultStrideC
=
ck
::
is_same_v
<
CLayout
,
Row
>
?
O
:
M
;
StrideA
=
(
StrideA
<
0
)
?
DefaultStrideA
:
StrideA
;
StrideB0
=
(
StrideB0
<
0
)
?
DefaultStrideB0
:
StrideB0
;
StrideB1
=
(
StrideB1
<
0
)
?
DefaultStrideB1
:
StrideB1
;
StrideC
=
(
StrideC
<
0
)
?
DefaultStrideC
:
StrideC
;
const
int
DefaultBatchStrideA
=
(
ck
::
is_same_v
<
ALayout
,
Col
>
?
K
:
M
)
*
StrideA
;
const
int
DefaultBatchStrideB0
=
(
ck
::
is_same_v
<
B0Layout
,
Col
>
?
N
:
K
)
*
StrideB0
;
const
int
DefaultBatchStrideB1
=
(
ck
::
is_same_v
<
B1Layout
,
Col
>
?
O
:
N
)
*
StrideB1
;
const
int
DefaultBatchStrideC
=
(
ck
::
is_same_v
<
CLayout
,
Col
>
?
O
:
M
)
*
StrideC
;
BatchStrideA
=
BatchStrideA
<
0
?
DefaultBatchStrideA
:
BatchStrideA
;
BatchStrideB0
=
BatchStrideB0
<
0
?
DefaultBatchStrideB0
:
BatchStrideB0
;
BatchStrideB1
=
BatchStrideB1
<
0
?
DefaultBatchStrideB1
:
BatchStrideB1
;
BatchStrideC
=
BatchStrideC
<
0
?
DefaultBatchStrideC
:
BatchStrideC
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
Row
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
batch_stride
,
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
batch_stride
,
1
,
stride
}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
BatchStrideA
,
ALayout
{}));
Tensor
<
B0DataType
>
b0_g_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB0
,
BatchStrideB0
,
B0Layout
{}));
Tensor
<
B1DataType
>
b1_g_n_o
(
f_host_tensor_descriptor
(
BatchCount
,
N
,
O
,
StrideB1
,
BatchStrideB1
,
B1Layout
{}));
Tensor
<
CDataType
>
c_g_m_o_host_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
StrideC
,
BatchStrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_g_m_o_device_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
O
,
StrideC
,
BatchStrideC
,
CLayout
{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor
<
ADataType
>
acc0_g_m_n
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
N
,
M
*
N
,
Row
{}));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_g_k_n: "
<<
b0_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_g_n_o: "
<<
b1_g_n_o
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_g_m_o: "
<<
c_g_m_o_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
3
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
3
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
3
});
break
;
case
2
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
0.0
,
1.0
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_3
<
B1DataType
>
{
-
0.5
,
0.5
});
break
;
case
3
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
break
;
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
DeviceMem
a_g_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSize
());
DeviceMem
b0_g_k_n_device_buf
(
sizeof
(
B0DataType
)
*
b0_g_k_n
.
mDesc
.
GetElementSize
());
DeviceMem
b1_g_n_o_device_buf
(
sizeof
(
B1DataType
)
*
b1_g_n_o
.
mDesc
.
GetElementSize
());
DeviceMem
c_g_m_o_device_buf
(
sizeof
(
CDataType
)
*
c_g_m_o_device_result
.
mDesc
.
GetElementSize
());
a_g_m_k_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b0_g_k_n_device_buf
.
ToDevice
(
b0_g_k_n
.
mData
.
data
());
b1_g_n_o_device_buf
.
ToDevice
(
b1_g_n_o
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b0_element_op
=
B0ElementOp
{};
auto
acc0_element_op
=
Acc0ElementOp
{};
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmGemm
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
>
;
// get device op instances
const
auto
op_ptrs
=
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
if
(
do_verification
)
{
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
auto
ref_gemm0_invoker
=
ref_gemm0
.
MakeInvoker
();
auto
ref_gemm0_argument
=
ref_gemm0
.
MakeArgument
(
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
PassThrough
{});
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
auto
ref_gemm1
=
ReferenceGemm1Instance
{};
auto
ref_gemm1_invoker
=
ref_gemm1
.
MakeInvoker
();
auto
ref_gemm1_argument
=
ref_gemm1
.
MakeArgument
(
acc0_g_m_n
,
b1_g_n_o
,
c_g_m_o_host_result
,
PassThrough
{},
b1_element_op
,
c_element_op
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
}
std
::
string
best_op_name
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device op instances
for
(
auto
&
op_ptr
:
op_ptrs
)
{
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_g_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_g_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_g_n_o_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_g_m_o_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
O
,
BatchCount
,
StrideA
,
StrideB0
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB0
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b0_element_op
,
acc0_element_op
,
b1_element_op
,
c_element_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_g_m_o_device_result
.
mData
,
c_g_m_o_host_result
.
mData
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a_g_m_k: "
,
a_g_m_k
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b0_g_k_n : "
,
b0_g_k_n
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b1_g_n_o : "
,
b1_g_n_o
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_g_m_o_host_result : "
,
c_g_m_o_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_g_m_o_device_result : "
,
c_g_m_o_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
else
{
std
::
cout
<<
op_ptr
->
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
return
pass
;
}
}
// namespace profiler
}
// namespace ck
test/CMakeLists.txt
View file @
8aa44bcd
...
@@ -40,6 +40,7 @@ add_subdirectory(gemm_split_k)
...
@@ -40,6 +40,7 @@ add_subdirectory(gemm_split_k)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
batched_gemm_gemm
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
reduce
)
add_subdirectory
(
reduce
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
convnd_fwd
)
...
...
test/batched_gemm_gemm/CMakeLists.txt
0 → 100644
View file @
8aa44bcd
add_custom_target
(
test_batched_gemm_gemm
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
\ No newline at end of file
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
0 → 100644
View file @
8aa44bcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
template
<
typename
Tuple
>
class
TestBatchedGemmGemmFP16
:
public
TestBatchedGemmGemm
<
Tuple
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
Row
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestBatchedGemmGemmFP16
,
KernelTypes
);
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Bench_FP16
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
256
,
256
,
64
,
64
,
768
},
{
256
,
256
,
128
,
128
,
768
},
{
512
,
512
,
64
,
64
,
768
},
{
512
,
512
,
128
,
128
,
768
},
{
1024
,
1024
,
64
,
64
,
768
},
{
1024
,
1024
,
128
,
128
,
768
},
{
2048
,
2048
,
64
,
64
,
768
},
{
2048
,
2048
,
128
,
128
,
768
},
{
4096
,
4096
,
64
,
64
,
768
},
{
4096
,
4096
,
128
,
128
,
768
},
};
this
->
bench_
=
true
;
this
->
verify_
=
false
;
this
->
Run
();
}
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
0 → 100644
View file @
8aa44bcd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
using
F16
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
typename
Tuple
>
struct
TestBatchedGemmGemm
:
public
::
testing
::
Test
{
using
ADataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
B0Layout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
{
256
,
256
,
64
,
64
,
4
},
{
256
,
256
,
128
,
128
,
4
},
{
512
,
512
,
64
,
64
,
2
},
{
512
,
512
,
128
,
128
,
2
},
{
1024
,
1024
,
64
,
64
,
1
},
{
1024
,
1024
,
128
,
128
,
1
},
};
bool
bench_
=
false
;
bool
verify_
=
true
;
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
BatchCount
)
{
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_gemm_impl
<
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ALayout
,
B0Layout
,
B1Layout
,
CLayout
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
BatchCount
);
EXPECT_TRUE
(
pass
);
}
void
Run
()
{
for
(
auto
lengths
:
this
->
lengths_
)
{
int
M
=
lengths
[
0
];
int
N
=
lengths
[
1
];
int
K
=
lengths
[
2
];
int
O
=
lengths
[
3
];
int
BatchCount
=
lengths
[
4
];
this
->
RunSingle
(
M
,
N
,
K
,
O
,
BatchCount
);
}
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment