gaoqiong / composable_kernel

Commit 6e0a93d2, authored Sep 16, 2022 by wangshaojie6

    add test

Parent: 8cdcad67
Showing 4 changed files with 391 additions and 12 deletions.
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp  (+36, -12)
test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt  (+5, -0)
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_fp16.cpp  (+161, -0)
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_util.hpp  (+189, -0)
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp
@@ -38,7 +38,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                int N,
                int K,
                int O,
-               int BatchCount = 1,
+               int G0,
+               int G1,
                int StrideA  = -1,
                int StrideB0 = -1,
                int StrideB1 = -1,
@@ -46,7 +47,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                int BatchStrideA  = -1,
                int BatchStrideB0 = -1,
                int BatchStrideB1 = -1,
-               int BatchStrideC  = -1)
+               int BatchStrideC  = -1,
+               float alpha       = 1.f)
 {
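The new alpha parameter defaults to 1.f so existing call sites keep working unchanged; it presumably carries the softmax scaling factor named by the "scale" in the function's name (in attention workloads this is typically 1/sqrt(K)).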
@@ -68,7 +70,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                                         AccDataType,
                                         AElementOp,
                                         B0ElementOp,
-                                        CElementOp>;
+                                        Acc0ElementOp>;

     // Ref Softmax: fp32 in, various type out
     using ReferenceSoftmaxInstance =
@@ -85,6 +87,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     bool pass = true;

+    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
+
     const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
     const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
     const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
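With lengths {G0, G1, M, O} and strides {M * G1 * O, O, G1 * O, 1}, element (g0, g1, m, o) lives at offset g0*(M*G1*O) + g1*O + m*(G1*O) + o, so the logical (G0, G1, M, O) tensor is stored physically in [G0, M, G1, O] order, which matches the "permute" in the function's name. For example, with G0 = 2, G1 = 3, M = 4, O = 8 the strides evaluate to {96, 8, 24, 1}, and element (1, 2, 3, 5) sits at 1*96 + 2*8 + 3*24 + 5 = 189.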
@@ -105,6 +110,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
     BatchStrideC  = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;

+    const int BatchCount = G0 * G1;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
@@ -130,18 +137,22 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
         f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
     Tensor<B1DataType> b1_g_n_o(
         f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-    Tensor<CDataType> c_g_m_o_host_result(
-        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
-    Tensor<CDataType> c_g_m_o_device_result(
-        f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
+    Tensor<CDataType> c_gs_ms_os_host_result(
+        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
+        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
+    Tensor<CDataType> c_gs_ms_os_device_result(
+        std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
+        std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));

     // Host verification: Output of Gemm0 is input A of Gemm1
     Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
     Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
+    Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
+                                          std::vector<int>{M * O, O, 1});

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
     std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
+    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

     std::srand(1); // work around test flakiness
     switch(init_method)
@@ -178,7 +189,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
     DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
     DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-    DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize());
+    DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
+                                    c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());

     a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
     b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
@@ -220,7 +232,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
         ref_gemm0_invoker.Run(ref_gemm0_argument);

         // mask out upper triangle
         acc0_g_m_n.ForEach([&](auto& self, auto idx) {
             if(idx[1] < idx[2])
                 self(idx) = -ck::NumericLimits<float>::Infinity();
         });

         auto ref_softmax         = ReferenceSoftmaxInstance{};
         auto ref_softmax_invoker = ref_softmax.MakeInvoker();
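In the masking step above, idx[0] is the batch index g while idx[1] and idx[2] are the row m and column n of the G x M x N score tensor, so entries with m < n (the upper triangle) are driven to negative infinity and receive zero weight after softmax; this is the usual causal attention mask. A minimal self-contained sketch of the same effect, independent of the CK Tensor type (sizes and the zero scores are illustrative):

    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main()
    {
        const int M = 4, N = 4;
        float score[M][N] = {}; // uniform (zero) scores for illustration

        // mask out upper triangle (n > m), mirroring the ForEach above
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                if(m < n)
                    score[m][n] = -std::numeric_limits<float>::infinity();

        // row-wise softmax: std::exp(-inf) == 0, so masked entries get zero weight
        for(int m = 0; m < M; ++m)
        {
            float sum = 0.f;
            for(int n = 0; n < N; ++n)
                sum += std::exp(score[m][n]);
            for(int n = 0; n < N; ++n)
                std::printf("%5.2f ", std::exp(score[m][n]) / sum);
            std::printf("\n"); // row m spreads its weight evenly over columns 0..m
        }
        return 0;
    }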
@@ -234,6 +248,16 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
                                                       a1_g_m_n,
                                                       b1_g_n_o,
                                                       c_g_m_o_host_result,
                                                       PassThrough{},
                                                       b1_element_op,
                                                       c_element_op);

         ref_gemm1_invoker.Run(ref_gemm1_argument);
+
+        // permute
+        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0];
+            const size_t& g1 = idx[1];
+
+            const size_t g = g0 * G1 + g1;
+
+            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
+        });
     }

     std::string best_op_name;
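The host reference thus computes a flat-batch c_g_m_o result first and then scatters it into the permuted c_gs_ms_os layout; the flat batch index is the row-major flattening g = g0 * G1 + g1, so with G0 = 2 and G1 = 3, for instance, (g0, g1) = (1, 2) lands in flat batch g = 1 * 3 + 2 = 5.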
@@ -302,7 +326,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
         c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());

         pass = pass &
-               ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
+               ck::utils::check_err(c_g_m_o_device_result.mData, c_gs_ms_os_host_result.mData);

         if(do_log)
         {
@@ -313,7 +337,7 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
             LogRangeAsType<float>(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",")
                 << std::endl;
-            LogRangeAsType<float>(
-                std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",")
+            LogRangeAsType<float>(
+                std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                 << std::endl;
             LogRangeAsType<float>(
                 std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",")
...
test/batched_gemm_masking_scale_softmax_gemm_permute/CMakeLists.txt (new file, 0 → 100644)
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
\ No newline at end of file
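Assuming an already-configured composable_kernel build tree (the build directory and binary location below are illustrative), the new target builds and runs like any other gtest binary:

    cmake --build build --target test_batched_gemm_softmax_gemm_fp16
    ./build/bin/test_batched_gemm_softmax_gemm_fp16 --gtest_filter='TestBatchedGemmSoftmaxGemmFP16*'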
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_fp16.cpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_batched_gemm_softmax_gemm_util.hpp"

template <typename Tuple>
class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>>;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); }

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 1},
        {128, 128, 136, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 1},
        {128, 128, 129, 128, 1},
    };
    this->Run();
}

// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16_OddO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {256, 256, 64, 64, 768},
        {256, 256, 128, 128, 768},
        {512, 512, 64, 64, 768},
        {512, 512, 128, 128, 768},
        {1024, 1024, 64, 64, 768},
        {1024, 1024, 128, 128, 768},
        {2048, 2048, 64, 64, 768},
        {2048, 2048, 128, 128, 768},
        {4096, 4096, 64, 64, 768},
        {4096, 4096, 128, 128, 768},
    };
    this->bench_  = true;
    this->verify_ = false;
    this->Run();
}
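googletest compiles but skips tests whose names begin with DISABLED_, so the benchmark configuration above stays out of normal test runs; it can still be invoked on demand (binary name from the CMakeLists above, path illustrative):

    ./test_batched_gemm_softmax_gemm_fp16 --gtest_also_run_disabled_tests --gtest_filter='*Bench_FP16*'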
using ck::tensor_operation::device::GemmSpecialization;

// TODO: enable KPadding tests when it is implemented
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
{
    int P = 120; // requires padding
    int Q = 128; // does not require padding

    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
    // EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
    // clang-format on
}

TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
{
    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
    // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
    // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
    // EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
    // clang-format on
}
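The arithmetic behind the commented-out expectations: per the comments, K and O reads are vectorized, so SizeKRaw % ABSrcScalarPerVector == 0 and SizeORaw % B1SrcScalarPerVector == 0 must hold regardless of padding mode; with a source vector width of 8 (as in this instance's block-transfer parameters), K = 129 and K = 130 fail because 129 % 8 == 1 and 130 % 8 == 2. The live EXPECT_FALSE fails for a different reason: GemmSpecialization::Default performs no padding, so K = 120 presumably has to be a multiple of KPerBlock = 32, and 120 % 32 == 24.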
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {49, 49, 64, 64, 24},
        {64, 49, 64, 64, 24},
        {1020, 1020, 64, 128, 24},
        {576, 576, 64, 64, 24},
    };
    this->bench_ = true;
    this->Run();
}
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_softmax_gemm_util.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"

#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"

using ck::tensor_operation::device::GemmSpecialization;

template <ck::index_t N>
using I = ck::Number<N>;

using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <typename Tuple>
struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
{
    using ADataType  = std::tuple_element_t<0, Tuple>;
    using B0DataType = std::tuple_element_t<1, Tuple>;
    using B1DataType = std::tuple_element_t<2, Tuple>;
    using CDataType  = std::tuple_element_t<3, Tuple>;
    using ALayout    = std::tuple_element_t<4, Tuple>;
    using B0Layout   = std::tuple_element_t<5, Tuple>;
    using B1Layout   = std::tuple_element_t<6, Tuple>;
    using CLayout    = std::tuple_element_t<7, Tuple>;

    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 4},
        {256, 256, 128, 128, 4},
        {512, 512, 64, 64, 2},
        {512, 512, 128, 128, 2},
        {1024, 1024, 64, 64, 1},
        {1024, 1024, 128, 128, 1},
    };

    bool bench_  = false;
    bool verify_ = true;

    void RunSingle(int M, int N, int K, int O, int BatchCount)
    {
        bool pass = ck::profiler::profile_batched_gemm_softmax_gemm_impl<ADataType,
                                                                         B0DataType,
                                                                         B1DataType,
                                                                         CDataType,
                                                                         ALayout,
                                                                         B0Layout,
                                                                         B1Layout,
                                                                         CLayout>(
            verify_, 1, false, bench_, M, N, K, O, BatchCount);
        EXPECT_TRUE(pass);
    }

    void Run()
    {
        for(auto lengths : this->lengths_)
        {
            int M          = lengths[0];
            int N          = lengths[1];
            int K          = lengths[2];
            int O          = lengths[3];
            int BatchCount = lengths[4];

            this->RunSingle(M, N, K, O, BatchCount);
        }
    }
};

template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ALayout  = Row;
    using B0Layout = Col;
    using B1Layout = Row;
    using CLayout  = Row;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = float;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = PassThrough;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
            ALayout,
            B0Layout,
            B1Layout,
            CLayout,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

    bool IsSupported(int M, int N, int K, int O)
    {
        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          M,
                                          N,
                                          K,
                                          O,
                                          0,              // BatchCount
                                          0,              // StrideA
                                          0,              // StrideB0
                                          0,              // StrideB1
                                          0,              // StrideC
                                          0,              // BatchStrideA
                                          0,              // BatchStrideB0
                                          0,              // BatchStrideB1
                                          0,              // BatchStrideC
                                          PassThrough{},  // a_element_op
                                          PassThrough{},  // b0_element_op
                                          PassThrough{},  // acc0_element_op
                                          PassThrough{},  // b1_element_op
                                          PassThrough{}); // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};
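Note the design of IsSupported: MakeArgument is called with null device pointers because the argument object is only passed to IsSupportedArgument, so nothing is allocated and no kernel is launched; the wrapper is a pure size/layout query. A usage sketch (the specialization choice is arbitrary):

    // Ask, without touching the GPU, whether this tile configuration
    // covers a given (M, N, K, O) problem size.
    bool ok = DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<
                  GemmSpecialization::MNPadding>{}
                  .IsSupported(120, 120, 128, 128);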