Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
200dd06b
Commit
200dd06b
authored
Sep 16, 2022
by
wangshaojie6
Browse files
add test file
parent
6e0a93d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
35 deletions
+41
-35
test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
...ed_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
+3
-3
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
..._batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
+0
-0
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
..._batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
+38
-32
No files found.
test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
View file @
200dd06b
add_custom_target
(
test_batched_gemm_softmax_gemm
)
add_custom_target
(
test_batched_gemm_
masking_scale_
softmax_gemm
_permute
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_masking_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
add_dependencies
(
test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_softmax_gemm_fp16
)
\ No newline at end of file
\ No newline at end of file
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_fp16.cpp
→
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_
masking_scale_
softmax_gemm_
permute_
fp16.cpp
View file @
200dd06b
File moved
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_util.hpp
→
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_
masking_scale_
softmax_gemm_
permute_
util.hpp
View file @
200dd06b
...
@@ -5,8 +5,8 @@
...
@@ -5,8 +5,8 @@
#include <vector>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_
permute_
xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
#include "profiler/include/profile_batched_gemm_
masking_scale_
softmax_gemm_
permute_
impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
template
<
ck
::
index_t
N
>
...
@@ -20,37 +20,37 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -20,37 +20,37 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
typename
Tuple
>
template
<
typename
Tuple
>
struct
TestBatchedGemmSoftmaxGemm
:
public
::
testing
::
Test
struct
TestBatchedGemmSoftmaxGemm
:
public
::
testing
::
Test
{
{
using
ADataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
B0Layout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B0Layout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
C
Layout
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
using
C
PermuteNumDims_G_M_O
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
{
256
,
256
,
64
,
64
,
4
},
{
256
,
256
,
64
,
64
,
6
,
4
},
{
256
,
256
,
128
,
128
,
4
},
{
256
,
256
,
128
,
128
,
4
,
6
},
{
512
,
512
,
64
,
64
,
2
},
{
512
,
512
,
64
,
64
,
3
,
2
},
{
512
,
512
,
128
,
128
,
2
},
{
512
,
512
,
128
,
128
,
2
,
3
},
{
1024
,
1024
,
64
,
64
,
1
},
{
1024
,
1024
,
64
,
64
,
3
,
1
},
{
1024
,
1024
,
128
,
128
,
1
},
{
1024
,
1024
,
128
,
128
,
1
,
1
},
};
};
bool
bench_
=
false
;
bool
bench_
=
false
;
bool
verify_
=
true
;
bool
verify_
=
true
;
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
BatchCount
)
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
G0
,
int
G1
)
{
{
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_softmax_gemm_impl
<
ADataType
,
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_
masking_scale_
softmax_gemm_
permute_
impl
<
ADataType
,
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
B0Layout
,
B0Layout
,
B1Layout
,
B1Layout
,
C
Layout
>
(
C
PermuteNumDims_G_M_O
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
BatchCount
);
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
EXPECT_TRUE
(
pass
);
EXPECT_TRUE
(
pass
);
}
}
...
@@ -63,9 +63,10 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
...
@@ -63,9 +63,10 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
int
N
=
lengths
[
1
];
int
N
=
lengths
[
1
];
int
K
=
lengths
[
2
];
int
K
=
lengths
[
2
];
int
O
=
lengths
[
3
];
int
O
=
lengths
[
3
];
int
BatchCount
=
lengths
[
4
];
int
G0
=
lengths
[
4
];
int
G1
=
lengths
[
5
];
this
->
RunSingle
(
M
,
N
,
K
,
O
,
BatchCount
);
this
->
RunSingle
(
M
,
N
,
K
,
O
,
G0
,
G1
);
}
}
}
}
};
};
...
@@ -74,36 +75,38 @@ template <GemmSpecialization GemmSpec>
...
@@ -74,36 +75,38 @@ template <GemmSpecialization GemmSpec>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
CPermuteNumDims_G_M_O
=
S
<
2
,
1
,
1
>
;
// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CShuffleDataType
=
F16
;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
Scale
;
using
B1ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
Permute
_Xdl_CShuffle
<
ALayout
,
ALayout
,
B0Layout
,
B0Layout
,
B1Layout
,
B1Layout
,
C
Layout
,
C
PermuteNumDims_G_M_O
,
ADataType
,
ADataType
,
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
...
@@ -155,7 +158,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
...
@@ -155,7 +158,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
true
>
;
// Masking
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
{
...
@@ -170,6 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
...
@@ -170,6 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
K
,
K
,
O
,
O
,
0
,
// BatchCount
0
,
// BatchCount
{
0
,
0
,
M
,
O
},
// gs ms ns lengths
{
0
,
O
,
0
,
1
},
// gs ms ns strides
0
,
// StrideA
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB0
0
,
// StrideB1
0
,
// StrideB1
...
@@ -180,7 +186,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
...
@@ -180,7 +186,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
0
,
// BatchStrideC
0
,
// BatchStrideC
PassThrough
{},
// a_element_op
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// acc0_element_op
Scale
{},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
PassThrough
{});
// c_element_op
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment