gaoqiong / composable_kernel

Commit 97dcc7b2, authored Sep 16, 2022 by wangshaojie6
add gtest for bmm masking scale softmax bmm permute
parent 200dd06b
Showing 7 changed files with 63 additions and 66 deletions:

+2  -2   library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp
+2  -2   library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
+8  -14  profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp
+1  -0   test/CMakeLists.txt
+3  -3   test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt
+45 -41  test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
+2  -4   test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp

@@ -50,7 +50,7 @@ struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
-                                                                      CLayout,
+                                                                      CPermuteNumDims_G_M_Gemm1N,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,

@@ -83,7 +83,7 @@ struct DeviceOperationInstanceFactory<
                  is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
     {
         if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
-                     is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
+                     is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
         {
             add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                 op_ptrs);
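Note: the factory change above replaces the C-side CLayout parameter with CPermuteNumDims_G_M_Gemm1N, so instances are now selected by the number of G/M/Gemm1N output dimensions rather than by a flat row/column layout. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of an instance registry of this kind; every name in it (DeviceOpBase, PaddedInstance, add_instances) is an illustrative stand-in, not the CK API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Abstract device op: concrete instances self-describe and report problem-size support.
struct DeviceOpBase
{
    virtual ~DeviceOpBase() = default;
    virtual std::string GetTypeString() const                  = 0;
    virtual bool IsSupported(int M, int N, int K, int O) const = 0;
};

struct PaddedInstance : DeviceOpBase
{
    std::string GetTypeString() const override { return "xdl_cshuffle_padded_sketch"; }
    bool IsSupported(int, int, int, int) const override { return true; }
};

// Analogue of add_device_batched_gemm_..._instance(op_ptrs): append concrete instances.
void add_instances(std::vector<std::unique_ptr<DeviceOpBase>>& op_ptrs)
{
    op_ptrs.push_back(std::make_unique<PaddedInstance>());
}

int main()
{
    std::vector<std::unique_ptr<DeviceOpBase>> op_ptrs;
    add_instances(op_ptrs);

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << " supports 128x128x32x128: "
                  << op->IsSupported(128, 128, 32, 128) << '\n';
}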
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt

-add_instance_library(device_batched_gemm_masking_softmax_gemm_permute_instance
+add_instance_library(device_batched_gemm_masking_scale_softmax_gemm_permute_instance
-    device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+    device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 )
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp

@@ -43,11 +43,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                                                                   int StrideA       = -1,
                                                                   int StrideB0      = -1,
                                                                   int StrideB1      = -1,
-                                                                  int StrideC       = -1,
                                                                   int BatchStrideA  = -1,
                                                                   int BatchStrideB0 = -1,
                                                                   int BatchStrideB1 = -1,
-                                                                  int BatchStrideC  = -1,
                                                                   float alpha       = 1.f)
 {
@@ -93,22 +91,18 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
     const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
     const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
-    const int DefaultStrideC  = ck::is_same_v<CLayout, Row> ? O : M;

     StrideA  = (StrideA < 0) ? DefaultStrideA : StrideA;
     StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
     StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
-    StrideC  = (StrideC < 0) ? DefaultStrideC : StrideC;

     const int DefaultBatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
     const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
     const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
-    const int DefaultBatchStrideC  = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;

     BatchStrideA  = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
     BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
     BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
-    BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

     const int BatchCount = G0 * G1;
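Note: since the C-side stride defaults are dropped here, it may help to spell out the convention the remaining defaults follow: a negative stride means "use the default", the default leading dimension of a row-major matrix is its column count (and vice versa), and the default batch stride is the other extent times that leading dimension. A minimal sketch with illustrative helper names (not CK code):

#include <iostream>

int default_stride(bool row_major, int rows, int cols)
{
    return row_major ? cols : rows; // leading dimension of one matrix
}

int default_batch_stride(bool row_major, int rows, int cols, int stride)
{
    return (row_major ? rows : cols) * stride; // elements between consecutive batches
}

int main()
{
    const int M = 128, K = 32;
    int StrideA      = -1; // -1 requests the default, as in the profiler signature
    int BatchStrideA = -1;

    StrideA      = StrideA < 0 ? default_stride(/*row_major=*/true, M, K) : StrideA;
    BatchStrideA = BatchStrideA < 0 ? default_batch_stride(true, M, K, StrideA) : BatchStrideA;

    std::cout << "StrideA=" << StrideA << " BatchStrideA=" << BatchStrideA << '\n'; // 32, 4096
}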
@@ -198,7 +192,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto a_element_op    = AElementOp{};
     auto b0_element_op   = B0ElementOp{};
-    auto acc0_element_op = Acc0ElementOp{};
+    auto acc0_element_op = Acc0ElementOp{alpha};
     auto b1_element_op   = B1ElementOp{};
     auto c_element_op    = CElementOp{};
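Note: the functional change in this hunk is that the accumulator element-op is now constructed with alpha instead of being default-constructed, so the scaling factor is actually passed through to the device op. As a rough sketch of the functor shape such a scaling op typically has (an illustrative stand-in, not the CK Scale definition):

#include <iostream>

// Illustrative stand-in for an alpha-carrying accumulator element-op.
struct ScaleSketch
{
    explicit ScaleSketch(float scale = 1.f) : scale_(scale) {}

    // Applied element-wise to the first GEMM's accumulator before the softmax.
    void operator()(float& y, const float& x) const { y = scale_ * x; }

    float scale_;
};

int main()
{
    ScaleSketch acc0_element_op{0.125f}; // e.g. alpha = 1/sqrt(head_dim) for head_dim = 64
    float y = 0.f;
    acc0_element_op(y, 16.f);
    std::cout << y << '\n'; // prints 2
}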
@@ -227,7 +221,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto ref_gemm0          = ReferenceGemm0Instance{};
     auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
     auto ref_gemm0_argument = ref_gemm0.MakeArgument(
-        a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
+        a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, Scale{alpha});

     ref_gemm0_invoker.Run(ref_gemm0_argument);

@@ -272,20 +266,20 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
         static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
         static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
         static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-        static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
+        static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
         M,
         N,
         K,
         O,
         BatchCount,
+        c_gs_ms_os_lengths,
+        c_gs_ms_os_strides,
         StrideA,
         StrideB0,
         StrideB1,
-        StrideC,
         BatchStrideA,
         BatchStrideB0,
         BatchStrideB1,
-        BatchStrideC,
         a_element_op,
         b0_element_op,
         acc0_element_op,

@@ -323,10 +317,10 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     if(do_verification)
     {
-        c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
+        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

         pass = pass &
-               ck::utils::check_err(c_g_m_o_device_result.mData, c_gs_ms_os_host_result.mData);
+               ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);

         if(do_log)
         {

@@ -340,7 +334,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                 std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                 << std::endl;
             LogRangeAsType<float>(
-                std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",")
+                std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
                 << std::endl;
         }
     }
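Note: the buffer rename from c_g_m_o to c_gs_ms_os, together with the new c_gs_ms_os_lengths / c_gs_ms_os_strides arguments, reflects that a permuted output cannot be described by a single StrideC/BatchStrideC pair. The sketch below shows what such lengths/strides can look like; the concrete [G0][M][G1][O] storage order is only an assumed example, chosen to illustrate a non-trivial permutation.

#include <array>
#include <iostream>

int main()
{
    const int G0 = 2, G1 = 3, M = 128, O = 64;

    // lengths in logical (G0, G1, M, O) order
    std::array<long, 4> c_gs_ms_os_lengths{G0, G1, M, O};

    // strides of the assumed [G0][M][G1][O] memory layout, expressed in logical order
    std::array<long, 4> c_gs_ms_os_strides{
        static_cast<long>(M) * G1 * O, // step in G0
        O,                             // step in G1
        static_cast<long>(G1) * O,     // step in M
        1};                            // step in O

    auto offset = [&](int g0, int g1, int m, int o) {
        return g0 * c_gs_ms_os_strides[0] + g1 * c_gs_ms_os_strides[1] +
               m * c_gs_ms_os_strides[2] + o * c_gs_ms_os_strides[3];
    };

    std::cout << offset(1, 2, 5, 7) << '\n'; // 1*24576 + 2*64 + 5*192 + 7 = 25671
}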
test/CMakeLists.txt

@@ -42,6 +42,7 @@ add_subdirectory(batched_gemm)
 add_subdirectory(batched_gemm_reduce)
 add_subdirectory(batched_gemm_gemm)
 add_subdirectory(batched_gemm_softmax_gemm)
+add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
 add_subdirectory(grouped_gemm)
 add_subdirectory(reduce)
 add_subdirectory(convnd_fwd)
test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt

 add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)
-add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
+add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
-target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_masking_softmax_gemm_permute_instance)
+target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
-add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_softmax_gemm_fp16)
+add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)
 \ No newline at end of file
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp

@@ -2,103 +2,107 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

 #include "gtest/gtest.h"
-#include "test_batched_gemm_softmax_gemm_util.hpp"
+#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"

 template <typename Tuple>
-class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
+class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
+    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
 {
 };

 // clang-format off
+template <ck::index_t... Is> using S = ck::Sequence<Is...>;
+using CPermuteNumDims_G_M_O = S<2, 1, 1>;
+// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
 using KernelTypes = ::testing::Types<
-    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
+    std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
     >;
 // clang-format on

-TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
+TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes);
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); }
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadM)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {136, 128, 32, 128, 1},
+        {136, 128, 32, 128, 2, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadN)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 136, 32, 128, 1},
+        {128, 136, 32, 128, 3, 2},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 40, 128, 1},
+        {128, 128, 40, 128, 2, 4},
-        {128, 128, 136, 128, 1},
+        {128, 128, 136, 128, 4, 2},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadO)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 32, 136, 1},
+        {128, 128, 32, 136, 1, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddM)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {129, 128, 32, 128, 1},
+        {129, 128, 32, 128, 2, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddN)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 129, 32, 128, 1},
+        {128, 129, 32, 128, 4, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 33, 128, 1},
+        {128, 128, 33, 128, 2, 3},
-        {128, 128, 129, 128, 1},
+        {128, 128, 129, 128, 2, 3},
     };
     this->Run();
 }

 // If kernel B1Layout is RowMajor, expect not to support odd O size
-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddO)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {128, 128, 32, 129, 1},
+        {128, 128, 32, 129, 2, 3},
     };
     this->Run();
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {256, 256, 64, 64, 768},
+        {256, 256, 64, 64, 48, 16},
-        {256, 256, 128, 128, 768},
+        {256, 256, 128, 128, 48, 16},
-        {512, 512, 64, 64, 768},
+        {512, 512, 64, 64, 48, 16},
-        {512, 512, 128, 128, 768},
+        {512, 512, 128, 128, 48, 16},
-        {1024, 1024, 64, 64, 768},
+        {1024, 1024, 64, 64, 48, 16},
-        {1024, 1024, 128, 128, 768},
+        {1024, 1024, 128, 128, 48, 16},
-        {2048, 2048, 64, 64, 768},
+        {2048, 2048, 64, 64, 48, 16},
-        {2048, 2048, 128, 128, 768},
+        {2048, 2048, 128, 128, 48, 16},
-        {4096, 4096, 64, 64, 768},
+        {4096, 4096, 64, 64, 48, 16},
-        {4096, 4096, 128, 128, 768},
+        {4096, 4096, 128, 128, 48, 16},
     };
     this->bench_  = true;
     this->verify_ = false;

@@ -108,7 +112,7 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
 using ck::tensor_operation::device::GemmSpecialization;

 // TODO: enable KPadding tests when it is implemented
-TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
+TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
 {
     int P = 120; // requires padding
     int Q = 128; // do not require padding

@@ -134,7 +138,7 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
     // clang-format on
 }

-TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
+TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
 {
     // IsSupported(M, N, K, O)
     // clang-format off

@@ -148,13 +152,13 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
     // clang-format on
 }

-TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
 {
     this->lengths_ = std::vector<std::vector<int>>{
-        {49, 49, 64, 64, 24},
+        {49, 49, 64, 64, 4, 6},
-        {64, 49, 64, 64, 24},
+        {64, 49, 64, 64, 4, 6},
-        {1020, 1020, 64, 128, 24},
+        {1020, 1020, 64, 128, 4, 6},
-        {576, 576, 64, 64, 24},
+        {576, 576, 64, 64, 4, 6},
     };
     this->bench_ = true;
     this->Run();
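Note: for readers who have not used typed gtest fixtures, the structure of the file above reduces to the sketch below: a fixture template holds lengths_ (each entry now being {M, N, K, O, G0, G1} instead of {M, N, K, O, BatchCount}) plus bench_/verify_ flags, TYPED_TEST_SUITE binds it to a tuple type list, and each TYPED_TEST overrides lengths_ and calls Run(). The fixture body here is a stub; the real one lives in test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp.

#include <tuple>
#include <vector>
#include "gtest/gtest.h"

template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermuteSketch : public ::testing::Test
{
    // Each entry is {M, N, K, O, G0, G1}; G0 * G1 is the batch count.
    std::vector<std::vector<int>> lengths_ = {{128, 128, 32, 128, 2, 3}};
    bool bench_  = false;
    bool verify_ = true;

    void Run()
    {
        for(const auto& len : lengths_)
            ASSERT_EQ(len.size(), 6u); // stand-in for the real device-vs-host check
    }
};

using KernelTypesSketch = ::testing::Types<std::tuple<float>>;
TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteSketch, KernelTypesSketch);

TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteSketch, Test_PadM_Sketch)
{
    this->lengths_ = std::vector<std::vector<int>>{{136, 128, 32, 128, 2, 3}};
    this->Run();
}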
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp

@@ -18,7 +18,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

 template <typename Tuple>
-struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
+struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
 {
     using ADataType  = std::tuple_element_t<0, Tuple>;
     using B0DataType = std::tuple_element_t<1, Tuple>;

@@ -179,14 +179,12 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             0,              // StrideA
             0,              // StrideB0
             0,              // StrideB1
-            0,              // StrideC
             0,              // BatchStrideA
             0,              // BatchStrideB0
             0,              // BatchStrideB1
-            0,              // BatchStrideC
             PassThrough{},  // a_element_op
             PassThrough{},  // b0_element_op
-            Scale{},        // acc0_element_op
+            Scale{1.f},     // acc0_element_op
             PassThrough{},  // b1_element_op
             PassThrough{}); // c_element_op
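Note: the tuple-driven parameterization used by the renamed fixture boils down to positional type extraction with std::tuple_element_t, as in the self-contained stand-in below (not the CK header; F16Like is a placeholder type introduced purely for illustration).

#include <tuple>
#include <type_traits>

template <typename Tuple>
struct FixtureSketch
{
    using ADataType  = std::tuple_element_t<0, Tuple>;
    using B0DataType = std::tuple_element_t<1, Tuple>;
};

using F16Like = unsigned short; // placeholder for ck::half_t, purely for illustration

static_assert(std::is_same_v<FixtureSketch<std::tuple<F16Like, float>>::ADataType, F16Like>);
static_assert(std::is_same_v<FixtureSketch<std::tuple<F16Like, float>>::B0DataType, float>);

int main() {}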