Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
55b19fce
Commit
55b19fce
authored
Aug 22, 2022
by
Anthony Chang
Browse files
test gemm_gemm padding
parent
cbb9be8b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
245 additions
and
2 deletions
+245
-2
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
...tion/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+2
-1
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+4
-1
profiler/include/profile_batched_gemm_gemm_impl.hpp
profiler/include/profile_batched_gemm_gemm_impl.hpp
+6
-0
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+112
-0
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
+121
-0
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
55b19fce
...
...
@@ -700,7 +700,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
">"
;
<<
B1K1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
// clang-format on
return
str
.
str
();
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
55b19fce
...
...
@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
...
...
@@ -37,7 +38,9 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
...
...
profiler/include/profile_batched_gemm_gemm_impl.hpp
View file @
55b19fce
...
...
@@ -195,6 +195,12 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
// early fail when no instances are found
if
(
op_ptrs
.
size
()
==
0
)
{
return
false
;
}
if
(
do_verification
)
{
auto
ref_gemm0
=
ReferenceGemm0Instance
{};
...
...
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
View file @
55b19fce
...
...
@@ -19,6 +19,72 @@ TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
136
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
136
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
40
,
128
,
1
},
{
128
,
128
,
136
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_PadO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
136
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
129
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
129
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Test_FP16_OddK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
33
,
128
,
1
},
{
128
,
128
,
129
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Test_FP16_OddO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
129
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
DISABLED_Bench_FP16
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
...
...
@@ -37,3 +103,49 @@ TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
this
->
verify_
=
false
;
this
->
Run
();
}
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMatch
)
{
int
P
=
129
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KOPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKOPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKOPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
}
TEST
(
TestBatchedGemmGemmInterface
,
GemmSpecializationSizeMismatch
)
{
int
P
=
129
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
}
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
View file @
55b19fce
...
...
@@ -4,8 +4,12 @@
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
...
...
@@ -66,3 +70,120 @@ struct TestBatchedGemmGemm : public ::testing::Test
}
}
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmGemm_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// StrideC
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
0
,
// BatchStrideC
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
return
gemm
.
IsSupportedArgument
(
argument
);
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment