Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4fec5ad3
Commit
4fec5ad3
authored
Oct 28, 2022
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/composable_kernel
into wmma_op
parents
24faa1fc
87fd1152
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
925 additions
and
377 deletions
+925
-377
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
...m_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
+33
-30
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
...m_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
+76
-43
test/gemm/CMakeLists.txt
test/gemm/CMakeLists.txt
+10
-0
test/gemm/gemm_bf16.cpp
test/gemm/gemm_bf16.cpp
+6
-51
test/gemm/gemm_fp16.cpp
test/gemm/gemm_fp16.cpp
+6
-51
test/gemm/gemm_fp32.cpp
test/gemm/gemm_fp32.cpp
+6
-51
test/gemm/gemm_fp64.cpp
test/gemm/gemm_fp64.cpp
+6
-51
test/gemm/gemm_int8.cpp
test/gemm/gemm_int8.cpp
+6
-51
test/gemm/gemm_standalone_xdl_fp16.cpp
test/gemm/gemm_standalone_xdl_fp16.cpp
+162
-0
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+65
-42
test/gemm/instance/gemm_f16_nn_instance.cpp
test/gemm/instance/gemm_f16_nn_instance.cpp
+86
-0
test/gemm/instance/gemm_f16_nn_instance.hpp
test/gemm/instance/gemm_f16_nn_instance.hpp
+41
-0
test/gemm/instance/gemm_f16_nt_instance.cpp
test/gemm/instance/gemm_f16_nt_instance.cpp
+86
-0
test/gemm/instance/gemm_f16_nt_instance.hpp
test/gemm/instance/gemm_f16_nt_instance.hpp
+41
-0
test/gemm/instance/gemm_f16_tn_instance.cpp
test/gemm/instance/gemm_f16_tn_instance.cpp
+86
-0
test/gemm/instance/gemm_f16_tn_instance.hpp
test/gemm/instance/gemm_f16_tn_instance.hpp
+41
-0
test/gemm/instance/gemm_f16_tt_instance.cpp
test/gemm/instance/gemm_f16_tt_instance.cpp
+86
-0
test/gemm/instance/gemm_f16_tt_instance.hpp
test/gemm/instance/gemm_f16_tt_instance.hpp
+41
-0
test/gemm/run_gemm_test.inc
test/gemm/run_gemm_test.inc
+41
-0
test/normalization/test_groupnorm_fp16.cpp
test/normalization/test_groupnorm_fp16.cpp
+0
-7
No files found.
test/batched_gemm_
masking_scale_
softmax_gemm_permute/test_batched_gemm_
masking_scale_
softmax_gemm_permute_fp16.cpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
View file @
4fec5ad3
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_
masking_scale_
softmax_gemm_permute_util.hpp"
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
template
<
typename
Tuple
>
class
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
...
...
@@ -10,13 +10,18 @@ class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
{
};
using
I1_t
=
ck
::
Number
<
1
>
;
using
I2_t
=
ck
::
Number
<
2
>
;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskOutUpperTriangle_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
// clang-format off
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
CPermuteNumDims_G_M_O
=
S
<
2
,
1
,
1
>
;
// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
CPermuteNumDims_G_M_O
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskOutUpperTriangle_t
>
>
;
// clang-format on
...
...
@@ -91,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
,
Bench_FP16_IrregularK
)
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
,
DISABLED_
Bench_FP16_IrregularK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{{
256
,
256
,
160
,
160
,
1
,
16
},
{
256
,
64
,
160
,
64
,
1
,
16
},
...
...
@@ -125,7 +130,6 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP1
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
// TODO: enable KPadding tests when it is implemented
TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface
,
GemmSpecializationSizeMatch
)
{
int
P
=
120
;
// requires padding
...
...
@@ -133,22 +137,22 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KOPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKOPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKOPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KOPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKOPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKOPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
}
...
...
@@ -156,13 +160,13 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
//
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
EXPECT_FALSE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
//
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
//
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
EXPECT_FALSE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
130
,
128
));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
//
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
EXPECT_FALSE
(
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
129
));
// clang-format on
}
...
...
@@ -174,6 +178,5 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{
1020
,
1020
,
64
,
128
,
4
,
6
},
{
576
,
576
,
64
,
64
,
4
,
6
},
};
this
->
bench_
=
true
;
this
->
Run
();
}
test/batched_gemm_
masking_scale_
softmax_gemm_permute/test_batched_gemm_
masking_scale_
softmax_gemm_permute_util.hpp
→
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp
View file @
4fec5ad3
...
...
@@ -4,10 +4,14 @@
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
using
ck
::
tensor_operation
::
device
::
MaskingSpecialization
;
using
ck
::
tensor_operation
::
device
::
TensorSpecialization
;
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
...
...
@@ -20,14 +24,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
typename
Tuple
>
struct
TestBatchedGemmMaskingScaleSoftmaxGemmPermute
:
public
::
testing
::
Test
{
using
ADataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
B0Layout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CPermuteNumDims_G_M_O
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
using
NumDimGType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
NumDimMType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
NumDimNType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
NumDimKType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
NumDimOType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
8
,
Tuple
>
;
using
Acc0BiasDataType
=
std
::
tuple_element_t
<
9
,
Tuple
>
;
using
Acc1BiasDataType
=
std
::
tuple_element_t
<
10
,
Tuple
>
;
using
MaskingType
=
std
::
tuple_element_t
<
11
,
Tuple
>
;
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
{
256
,
256
,
64
,
64
,
6
,
4
},
...
...
@@ -42,15 +50,20 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
G0
,
int
G1
)
{
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_masking_scale_softmax_gemm_permute_impl
<
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_softmax_gemm_permute_impl
<
NumDimGType
::
value
,
NumDimMType
::
value
,
NumDimNType
::
value
,
NumDimKType
::
value
,
NumDimOType
::
value
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskingType
::
value
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
EXPECT_TRUE
(
pass
);
}
...
...
@@ -72,19 +85,13 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
struct
DeviceInstanceWrapper_
G2M1N1K1O1_
TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
CPermuteNumDims_G_M_O
=
S
<
2
,
1
,
1
>
;
// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
...
...
@@ -103,14 +110,17 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
,
2
,
1
,
1
,
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
...
...
@@ -119,6 +129,10 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
B1ElementOp
,
CElementOp
,
GemmSpec
,
TensorSpecialization
::
Default
,
// ATensorSpec
TensorSpecialization
::
Default
,
// B0TensorSpec
TensorSpecialization
::
Default
,
// B1TensorSpec
TensorSpecialization
::
Default
,
// CTensorSpec
1
,
256
,
128
,
// MPerBlock
...
...
@@ -159,29 +173,48 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
true
>
;
// Masking
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
// MaskOutUpperTriangle
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
const
int
G0
=
1
,
G1
=
1
;
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B0 layout [G0, N, G1, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B1 layout [G0, N, G1, O]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
};
// C layout [G0, M, G1, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
};
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
{
0
,
0
,
M
,
O
},
// gs ms ns lengths
{
0
,
O
,
0
,
1
},
// gs ms ns strides
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
{},
// p_acc0_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
,
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
{},
// acc0_biases_gs_ms_ns_lengths
{},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
Scale
{
1.
f
},
// acc0_element_op
...
...
test/gemm/CMakeLists.txt
View file @
4fec5ad3
...
...
@@ -13,3 +13,13 @@ target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable
(
test_gemm_int8 gemm_int8.cpp
)
target_link_libraries
(
test_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_gemm_int8 PRIVATE device_gemm_instance
)
add_library
(
gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_f16_tt_instance.cpp
)
add_test_executable
(
test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp
)
target_link_libraries
(
test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility
)
target_include_directories
(
test_gemm_standalone_xdl_fp16 PRIVATE instance/
)
test/gemm/gemm_bf16.cpp
View file @
4fec5ad3
...
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
ADataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
#include "run_gemm_test.inc"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
int
main
()
{
return
run_gemm_test
();
}
test/gemm/gemm_fp16.cpp
View file @
4fec5ad3
...
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
#include "run_gemm_test.inc"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
int
main
()
{
return
run_gemm_test
();
}
test/gemm/gemm_fp32.cpp
View file @
4fec5ad3
...
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
#include "run_gemm_test.inc"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
int
main
()
{
return
run_gemm_test
();
}
test/gemm/gemm_fp64.cpp
View file @
4fec5ad3
...
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
double
;
using
BDataType
=
double
;
using
CDataType
=
double
;
using
AccDataType
=
double
;
using
ADataType
=
double
;
using
BDataType
=
double
;
using
CDataType
=
double
;
using
AccDataType
=
double
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
#include "run_gemm_test.inc"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
int
main
()
{
return
run_gemm_test
();
}
test/gemm/gemm_int8.cpp
View file @
4fec5ad3
...
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
#include "run_gemm_test.inc"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
int
main
()
{
return
run_gemm_test
();
}
test/gemm/gemm_standalone_xdl_fp16.cpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "gemm_f16_nn_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
#include "gemm_f16_tt_instance.hpp"
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
F16
=
ck
::
half_t
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
float
;
using
CDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
ck
::
gemm_util
::
GemmParams
;
using
ck
::
tensor_operation
::
device
::
BaseOperator
;
using
ck
::
tensor_operation
::
device
::
DeviceGemm
;
using
namespace
ck
::
tensor_operation
::
device
::
instance
;
using
DeviceGemmNN
=
DeviceGemm
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmNT
=
DeviceGemm
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmTN
=
DeviceGemm
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
DeviceGemmTT
=
DeviceGemm
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>
;
struct
LayoutConfig
{
bool
ARowMajor
;
bool
BRowMajor
;
bool
CRowMajor
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using
OpFactoryFn
=
void
(
*
)(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
);
std
::
vector
<
std
::
tuple
<
GemmParams
,
LayoutConfig
,
OpFactoryFn
>>
problems
=
{
// clang-format off
// 104 tiles
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x64
},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_128x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_128x64
},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_128x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_128x64
},
{
GemmParams
{
2048
,
3328
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_256x256
},
{
GemmParams
{
2048
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_256x128
},
{
GemmParams
{
1024
,
1664
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_128x128
},
{
GemmParams
{
1024
,
832
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_128x64
},
// 110 tiles
{
GemmParams
{
2560
,
2816
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x256
},
{
GemmParams
{
2560
,
1408
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_256x128
},
{
GemmParams
{
1280
,
1408
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x128
},
{
GemmParams
{
1280
,
704
,
4096
},
LayoutConfig
{
false
,
false
,
true
},
add_gemm_f16_nn_128x64
},
{
GemmParams
{
2560
,
2816
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x256
},
{
GemmParams
{
2560
,
1408
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_256x128
},
{
GemmParams
{
1280
,
1408
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_128x128
},
{
GemmParams
{
1280
,
704
,
4096
},
LayoutConfig
{
false
,
true
,
true
},
add_gemm_f16_nt_128x64
},
{
GemmParams
{
2560
,
2816
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_256x256
},
{
GemmParams
{
2560
,
1408
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_256x128
},
{
GemmParams
{
1280
,
1408
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_128x128
},
{
GemmParams
{
1280
,
704
,
4096
},
LayoutConfig
{
true
,
false
,
true
},
add_gemm_f16_tn_128x64
},
{
GemmParams
{
2560
,
2816
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_256x256
},
{
GemmParams
{
2560
,
1408
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_256x128
},
{
GemmParams
{
1280
,
1408
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_128x128
},
{
GemmParams
{
1280
,
704
,
4096
},
LayoutConfig
{
true
,
true
,
true
},
add_gemm_f16_tt_128x64
},
// clang-format on
};
bool
do_verification
=
true
;
bool
time_kernel
=
true
;
if
(
argc
==
1
)
{
// use default
}
else
if
(
argc
==
3
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
time_kernel
=
std
::
stoi
(
argv
[
2
]);
}
else
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
<<
"arg2: time kernel (0=no, 1=yes)"
<<
std
::
endl
;
return
0
;
}
bool
pass
=
true
;
for
(
auto
&
p
:
problems
)
{
GemmParams
&
problem_size
=
std
::
get
<
0
>
(
p
);
const
LayoutConfig
&
layout_config
=
std
::
get
<
1
>
(
p
);
const
auto
&
factory
=
std
::
get
<
2
>
(
p
);
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>
ops
;
factory
(
ops
);
// overwrite strides
problem_size
.
StrideA
=
layout_config
.
ARowMajor
?
problem_size
.
K
:
problem_size
.
M
;
problem_size
.
StrideB
=
layout_config
.
BRowMajor
?
problem_size
.
N
:
problem_size
.
K
;
problem_size
.
StrideC
=
layout_config
.
CRowMajor
?
problem_size
.
N
:
problem_size
.
M
;
if
(
!
layout_config
.
ARowMajor
&&
!
layout_config
.
BRowMajor
)
{
auto
op_ptr
=
dynamic_cast
<
DeviceGemmNN
*>
(
ops
[
0
].
get
());
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
op_ptr
,
problem_size
,
do_verification
,
time_kernel
);
}
else
if
(
!
layout_config
.
ARowMajor
&&
layout_config
.
BRowMajor
)
{
auto
op_ptr
=
dynamic_cast
<
DeviceGemmNT
*>
(
ops
[
0
].
get
());
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
op_ptr
,
problem_size
,
do_verification
,
time_kernel
);
}
else
if
(
layout_config
.
ARowMajor
&&
!
layout_config
.
BRowMajor
)
{
auto
op_ptr
=
dynamic_cast
<
DeviceGemmTN
*>
(
ops
[
0
].
get
());
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
op_ptr
,
problem_size
,
do_verification
,
time_kernel
);
}
else
if
(
layout_config
.
ARowMajor
&&
layout_config
.
BRowMajor
)
{
auto
op_ptr
=
dynamic_cast
<
DeviceGemmTT
*>
(
ops
[
0
].
get
());
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
op_ptr
,
problem_size
,
do_verification
,
time_kernel
);
}
}
std
::
cout
<<
(
pass
?
"ALL TESTS PASSED"
:
"SOME TESTS FAILED"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_util.hpp
View file @
4fec5ad3
...
...
@@ -16,21 +16,13 @@ namespace gemm_util {
struct
GemmParams
{
GemmParams
()
:
M
(
1024
),
N
(
1024
),
K
(
1024
),
StrideA
(
1024
),
StrideB
(
1024
),
StrideC
(
1024
),
alpha
(
1
),
beta
(
0
)
{
}
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideC
;
float
alpha
;
float
beta
;
ck
::
index_t
StrideA
=
1024
;
ck
::
index_t
StrideB
=
1024
;
ck
::
index_t
StrideC
=
1024
;
};
template
<
typename
GemmInstance
,
...
...
@@ -69,7 +61,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
Tensor
<
CDataType
>&
C
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
,
bool
time_kernel
)
{
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpaceSize
());
...
...
@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
{
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
params
.
M
*
params
.
N
*
params
.
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
params
.
M
*
params
.
K
+
sizeof
(
BDataType
)
*
params
.
K
*
params
.
N
+
sizeof
(
CDataType
)
*
params
.
M
*
params
.
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
return
true
;
...
...
@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
}
}
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
template
<
typename
AccDataType
>
struct
TestGemm
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
auto
PrepareGemmTensor
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
{
auto
f_host_tensor_descriptor
=
...
...
@@ -156,25 +158,42 @@ struct TestGemm
f_generate_tensor_value
(
a_m_k
,
ADataType
{});
f_generate_tensor_value
(
b_k_n
,
BDataType
{});
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
return
std
::
make_tuple
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
c_m_n_device_result
);
}
auto
operator
()(
const
DeviceGemmPtr_
&
gemmPtr
)
template
<
template
<
class
...
>
class
DeviceGemmPtr_
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
auto
operator
()(
DeviceGemmPtr_
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>*
gemmPtr
,
const
GemmParams
&
params
=
GemmParams
{},
bool
do_verification
=
true
,
bool
time_kernel
=
false
)
{
std
::
cout
<<
"ALayout = "
<<
ALayout
{}.
name
<<
", BLayout = "
<<
BLayout
{}.
name
<<
", CLayout = "
<<
CLayout
{}.
name
<<
std
::
endl
;
std
::
cout
<<
gemmPtr
->
GetTypeString
()
<<
std
::
endl
;
// Arrange
ck
::
gemm_util
::
GemmParams
params
;
params
.
M
=
1024
;
params
.
N
=
1024
;
params
.
K
=
1024
;
params
.
StrideA
=
1024
;
params
.
StrideB
=
1024
;
params
.
StrideC
=
1024
;
auto
host_tensors
=
PrepareGemmTensor
(
params
);
auto
host_tensors
=
PrepareGemmTensor
<
ADataType
,
BDataType
,
CDataType
,
ALayout
,
BLayout
,
CLayout
>
(
params
);
const
Tensor
<
ADataType
>&
a
=
std
::
get
<
0
>
(
host_tensors
);
const
Tensor
<
BDataType
>&
b
=
std
::
get
<
1
>
(
host_tensors
);
...
...
@@ -193,14 +212,18 @@ struct TestGemm
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
do_verification
)
{
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// Act
bool
is_supported
=
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
);
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
,
time_kernel
);
if
(
is_supported
)
if
(
is_supported
&&
do_verification
)
{
// Assert
bool
res
=
false
;
...
...
test/gemm/instance/gemm_f16_nn_instance.cpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_nn_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
gemm_f16_nn_256x256
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
256
,
32
,
2
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nn_256x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nn_128x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nn_128x64
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_gemm_f16_nn_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nn_256x256
{});
}
void
add_gemm_f16_nn_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nn_256x128
{});
}
void
add_gemm_f16_nn_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nn_128x128
{});
}
void
add_gemm_f16_nn_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nn_128x64
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_nn_instance.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_gemm_f16_nn_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nn_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nn_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nn_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_nt_instance.cpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
gemm_f16_nt_256x256
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
256
,
32
,
2
,
2
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nt_256x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nt_128x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_nt_128x64
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_gemm_f16_nt_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nt_256x256
{});
}
void
add_gemm_f16_nt_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nt_256x128
{});
}
void
add_gemm_f16_nt_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nt_128x128
{});
}
void
add_gemm_f16_nt_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_nt_128x64
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_nt_instance.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_gemm_f16_nt_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nt_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nt_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_nt_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_tn_instance.cpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
gemm_f16_tn_256x256
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tn_256x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tn_128x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tn_128x64
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_gemm_f16_tn_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tn_256x256
{});
}
void
add_gemm_f16_tn_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tn_256x128
{});
}
void
add_gemm_f16_tn_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tn_128x128
{});
}
void
add_gemm_f16_tn_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tn_128x64
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_tn_instance.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_gemm_f16_tn_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tn_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tn_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tn_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_tt_instance.cpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "gemm_f16_tt_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
using
gemm_f16_tt_256x256
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
256
,
32
,
8
,
2
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tt_256x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tt_128x128
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
using
gemm_f16_tt_128x64
=
std
::
tuple
<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_gemm_f16_tt_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tt_256x256
{});
}
void
add_gemm_f16_tt_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tt_256x128
{});
}
void
add_gemm_f16_tt_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tt_128x128
{});
}
void
add_gemm_f16_tt_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
)
{
add_device_operation_instances
(
instances
,
gemm_f16_tt_128x64
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/instance/gemm_f16_tt_instance.hpp
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
void
add_gemm_f16_tt_256x256
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tt_256x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tt_128x128
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
void
add_gemm_f16_tt_128x64
(
std
::
vector
<
std
::
unique_ptr
<
BaseOperator
>>&
instances
);
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
test/gemm/run_gemm_test.inc
0 → 100644
View file @
4fec5ad3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int
run_gemm_test
()
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
gemmPtr
.
get
());
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/normalization/test_groupnorm_fp16.cpp
View file @
4fec5ad3
...
...
@@ -45,13 +45,6 @@ class TestGroupnorm : public ::testing::Test
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>>
;
TYPED_TEST_SUITE
(
TestGroupnorm
,
KernelTypes
);
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment