Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b41e6019
Commit
b41e6019
authored
Sep 06, 2022
by
Po-Yen, Chen
Browse files
Merge branch 'develop' into feature/add-permute-device-op
parents
d356c871
868e5c55
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
525 additions
and
95 deletions
+525
-95
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
+80
-0
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+15
-2
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp
...nce/gpu/normalization/device_softmax_f16_f16_instance.cpp
+25
-17
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp
...nce/gpu/normalization/device_softmax_f32_f32_instance.cpp
+23
-15
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
+10
-3
profiler/include/profile_normalization_impl.hpp
profiler/include/profile_normalization_impl.hpp
+37
-16
profiler/src/profile_normalization.cpp
profiler/src/profile_normalization.cpp
+65
-23
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
+5
-4
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
+122
-0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
+121
-0
test/softmax/test_softmax_util.hpp
test/softmax/test_softmax_util.hpp
+22
-15
No files found.
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
0 → 100644
View file @
b41e6019
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can trigger compiler crash in mainline #9110 but not in #10738
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
4
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmGemm_Xdl_CShuffle
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
4
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
>
;
void
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmGemm
<
Row
,
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
b41e6019
...
@@ -26,6 +26,8 @@ using S = ck::Sequence<Is...>;
...
@@ -26,6 +26,8 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNOPadding
;
// Padding K is currently flawed
// c[g, m, n] = a[g, m, k] * b[g, n, k]
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
using
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
=
...
@@ -35,10 +37,21 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
...
@@ -35,10 +37,21 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
2
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
8
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
8
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
64
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
32
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
128
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
8
,
S
<
1
,
16
,
1
,
16
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
1
,
256
,
64
,
256
,
64
,
64
,
32
,
8
,
8
,
2
,
16
,
16
,
1
,
16
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
4
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
Row
,
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmPadded
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
>
// clang-format on
// clang-format on
>
;
>
;
...
...
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f16_f16_instance.cpp
View file @
b41e6019
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include <tuple>
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include <vector>
#include "ck/utility/data_type.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
namespace
{
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Pass
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
template
<
index_t
Rank
,
index_t
Reduce
>
template
<
index_t
Rank
,
index_t
Reduce
>
using
device_softmax_f16_f16_instances
=
std
::
tuple
<
using
device_softmax_f16_f16_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//
InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
//
InDataType, AccDataType, OutDataType,
InElementwiseOp, AccElementwiseOp,
Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
>
,
DeviceSoftmax
<
F16
,
F32
,
F16
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
>
DeviceSoftmax
Impl
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
3
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
2
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
3
,
2
>
{});
}
}
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
Pass
,
Pass
,
4
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
2
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f16_f16_instances
<
4
,
2
>
{});
...
...
library/src/tensor_operation_instance/gpu/normalization/device_softmax_f32_f32_instance.cpp
View file @
b41e6019
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_softmax
_impl
.hpp"
#include "ck/
utility/data_type
.hpp"
#include "ck/
tensor_operation/gpu/element/unary_element_wise_operation
.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
using
F32
=
float
;
namespace
{
using
F32
=
float
;
using
Pass
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
template
<
index_t
Rank
,
index_t
Reduce
>
template
<
index_t
Rank
,
index_t
Reduce
>
using
device_softmax_f32_f32_instances
=
std
::
tuple
<
using
device_softmax_f32_f32_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//
InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
//
InDataType, AccDataType, OutDataType,
InElementwiseOp, AccElementwiseOp,
Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
>
,
// fallback kernel
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
4
>
,
DeviceSoftmax
<
F32
,
F32
,
F32
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
4
>
DeviceSoftmax
Impl
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
4
>
// clang-format on
// clang-format on
>
;
>
;
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
3
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
2
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
3
,
2
>
{});
}
}
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
instances
)
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
Pass
,
Pass
,
4
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
1
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
2
>
{});
add_device_operation_instances
(
instances
,
device_softmax_f32_f32_instances
<
4
,
2
>
{});
...
...
profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
b41e6019
...
@@ -147,9 +147,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
...
@@ -147,9 +147,16 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
// Still unsure whether this kind of deterministic floating point accurary issue is expected
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
// or not. May want to try exact same approach as the GPU kernel in the host reference
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
5
,
5
});
// GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then,
// shrink the input value range as it is less likely to produce errors of around ~1e-3.
// a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
// b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
...
profiler/include/profile_normalization_impl.hpp
View file @
b41e6019
...
@@ -6,25 +6,36 @@
...
@@ -6,25 +6,36 @@
#include <iomanip>
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
void
add_device_softmax_f16_f16_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
namespace
{
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
void
add_device_softmax_f16_f16_rank3_instances
(
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceNormalizationPtr
>&
);
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f16_f16_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F16
,
F32
,
F16
,
PassThrough
,
PassThrough
,
4
>>&
);
void
add_device_softmax_f32_f32_rank3_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
3
>>&
);
void
add_device_softmax_f32_f32_rank4_instances
(
std
::
vector
<
DeviceSoftmaxPtr
<
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
4
>>&
);
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
...
@@ -57,7 +68,7 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
...
@@ -57,7 +68,7 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
template
<
>
std
::
string
type_to_string
<
int32_t
>
()
{
return
"int32"
;
}
template
<
>
std
::
string
type_to_string
<
int32_t
>
()
{
return
"int32"
;
}
// clang-format on
// clang-format on
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
>
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
>
void
profile_normalization_impl
(
int
do_verification
,
void
profile_normalization_impl
(
int
do_verification
,
int
init_method
,
int
init_method
,
bool
do_log
,
bool
do_log
,
...
@@ -69,6 +80,11 @@ void profile_normalization_impl(int do_verification,
...
@@ -69,6 +80,11 @@ void profile_normalization_impl(int do_verification,
AccDataType
beta
,
AccDataType
beta
,
NormType
norm_type
)
NormType
norm_type
)
{
{
if
(
Rank
!=
in_length
.
size
())
{
throw
std
::
runtime_error
(
"Input tensor rank is different from template argument Rank!"
);
}
Tensor
<
InDataType
>
in
=
in_strides
.
empty
()
?
Tensor
<
InDataType
>
(
in_length
)
Tensor
<
InDataType
>
in
=
in_strides
.
empty
()
?
Tensor
<
InDataType
>
(
in_length
)
:
Tensor
<
InDataType
>
(
in_length
,
in_strides
);
:
Tensor
<
InDataType
>
(
in_length
,
in_strides
);
Tensor
<
OutDataType
>
out
(
in
.
mDesc
);
Tensor
<
OutDataType
>
out
(
in
.
mDesc
);
...
@@ -99,30 +115,31 @@ void profile_normalization_impl(int do_verification,
...
@@ -99,30 +115,31 @@ void profile_normalization_impl(int do_verification,
std
::
vector
<
index_t
>
i_in_lengths
(
in
.
mDesc
.
GetLengths
().
begin
(),
in
.
mDesc
.
GetLengths
().
end
());
std
::
vector
<
index_t
>
i_in_lengths
(
in
.
mDesc
.
GetLengths
().
begin
(),
in
.
mDesc
.
GetLengths
().
end
());
std
::
vector
<
index_t
>
i_in_strides
(
in
.
mDesc
.
GetStrides
().
begin
(),
in
.
mDesc
.
GetStrides
().
end
());
std
::
vector
<
index_t
>
i_in_strides
(
in
.
mDesc
.
GetStrides
().
begin
(),
in
.
mDesc
.
GetStrides
().
end
());
// add device normalization instances
// add device softmax instances
std
::
vector
<
tensor_operation
::
device
::
DeviceNormalizationPtr
>
instances
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceOpPtr
=
tensor_operation
::
device
::
DeviceSoftmaxPtr
<
InDataType
,
AccDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
Rank
>
;
std
::
vector
<
DeviceOpPtr
>
instances
;
if
(
norm_type
==
NormType
::
SOFTMAX
)
if
(
norm_type
==
NormType
::
SOFTMAX
)
{
{
if
constexpr
(
is_same
<
InDataType
,
half_t
>::
value
&&
is_same
<
OutDataType
,
half_t
>::
value
&&
if
constexpr
(
is_same
<
InDataType
,
half_t
>::
value
&&
is_same
<
OutDataType
,
half_t
>::
value
&&
is_same
<
AccDataType
,
float
>::
value
)
is_same
<
AccDataType
,
float
>::
value
)
{
{
if
(
in_length
.
size
()
==
3
)
if
constexpr
(
Rank
==
3
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank3_instances
(
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank3_instances
(
instances
);
instances
);
else
if
constexpr
(
Rank
==
4
)
if
(
in_length
.
size
()
==
4
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank4_instances
(
tensor_operation
::
device
::
instance
::
add_device_softmax_f16_f16_rank4_instances
(
instances
);
instances
);
}
}
else
if
constexpr
(
is_same
<
InDataType
,
float
>::
value
&&
is_same
<
OutDataType
,
float
>::
value
&&
else
if
constexpr
(
is_same
<
InDataType
,
float
>::
value
&&
is_same
<
OutDataType
,
float
>::
value
&&
is_same
<
AccDataType
,
float
>::
value
)
is_same
<
AccDataType
,
float
>::
value
)
{
{
if
(
in_length
.
size
()
==
3
)
if
constexpr
(
Rank
==
3
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank3_instances
(
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank3_instances
(
instances
);
instances
);
else
if
constexpr
(
Rank
==
4
)
if
(
in_length
.
size
()
==
4
)
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank4_instances
(
tensor_operation
::
device
::
instance
::
add_device_softmax_f32_f32_rank4_instances
(
instances
);
instances
);
}
}
...
@@ -137,6 +154,8 @@ void profile_normalization_impl(int do_verification,
...
@@ -137,6 +154,8 @@ void profile_normalization_impl(int do_verification,
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
float
best_gb_per_sec
=
0
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
for
(
auto
&
inst_ptr
:
instances
)
for
(
auto
&
inst_ptr
:
instances
)
{
{
// Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
// Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3
...
@@ -153,7 +172,9 @@ void profile_normalization_impl(int do_verification,
...
@@ -153,7 +172,9 @@ void profile_normalization_impl(int do_verification,
&
alpha
,
&
alpha
,
&
beta
,
&
beta
,
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
());
out_dev
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
if
(
!
inst_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
inst_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
...
profiler/src/profile_normalization.cpp
View file @
b41e6019
...
@@ -50,7 +50,7 @@ struct ArgParser
...
@@ -50,7 +50,7 @@ struct ArgParser
void
print_help
()
void
print_help
()
{
{
std
::
cout
<<
"arg1: tensor operation (
layernorm/
batchnorm/softmax)
\n
"
std
::
cout
<<
"arg1: tensor operation (batchnorm/softmax)
\n
"
<<
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
<<
"arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)
\n
"
<<
"arg3: verification (0: no; 1: yes)
\n
"
<<
"arg3: verification (0: no; 1: yes)
\n
"
<<
"arg4: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
<<
"arg4: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
...
@@ -91,31 +91,73 @@ int profile_normalization(int argc, char* argv[])
...
@@ -91,31 +91,73 @@ int profile_normalization(int argc, char* argv[])
arg_parser
.
long_opts
[
"alpha"
].
empty
()
?
1
:
arg_parser
.
long_opts
[
"alpha"
][
0
];
arg_parser
.
long_opts
[
"alpha"
].
empty
()
?
1
:
arg_parser
.
long_opts
[
"alpha"
][
0
];
const
index_t
beta
=
arg_parser
.
long_opts
[
"beta"
].
empty
()
?
0
:
arg_parser
.
long_opts
[
"beta"
][
0
];
const
index_t
beta
=
arg_parser
.
long_opts
[
"beta"
].
empty
()
?
0
:
arg_parser
.
long_opts
[
"beta"
][
0
];
if
(
data_type
==
NormDataType
::
F16_F16
)
if
(
length
.
size
()
==
3
)
{
{
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
>
(
do_verification
,
if
(
data_type
==
NormDataType
::
F16_F16
)
init_method
,
{
do_log
,
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
,
3
>
(
time_kernel
,
do_verification
,
length
,
init_method
,
stride
,
do_log
,
reduce
,
time_kernel
,
float
(
alpha
),
length
,
float
(
beta
),
stride
,
norm_type
);
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
,
3
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
{
throw
std
::
runtime_error
(
"not implemented yet"
);
}
}
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
else
if
(
length
.
size
()
==
4
)
{
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
>
(
do_verification
,
if
(
data_type
==
NormDataType
::
F16_F16
)
init_method
,
{
do_log
,
ck
::
profiler
::
profile_normalization_impl
<
ck
::
half_t
,
float
,
ck
::
half_t
,
4
>
(
time_kernel
,
do_verification
,
length
,
init_method
,
stride
,
do_log
,
reduce
,
time_kernel
,
float
(
alpha
),
length
,
float
(
beta
),
stride
,
norm_type
);
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
if
(
data_type
==
NormDataType
::
F32_F32
)
{
ck
::
profiler
::
profile_normalization_impl
<
float
,
float
,
float
,
4
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
length
,
stride
,
reduce
,
float
(
alpha
),
float
(
beta
),
norm_type
);
}
else
{
throw
std
::
runtime_error
(
"not implemented yet"
);
}
}
}
else
else
{
{
...
...
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
View file @
b41e6019
...
@@ -11,7 +11,8 @@ class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
...
@@ -11,7 +11,8 @@ class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
Row
>
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Row
,
Row
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F16
,
Row
,
Col
,
Col
,
Row
>
>
;
>
;
// clang-format on
// clang-format on
...
@@ -68,7 +69,6 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
...
@@ -68,7 +69,6 @@ TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
this
->
Run
();
this
->
Run
();
}
}
// Currently expected that no kernels can support this case
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddK
)
TYPED_TEST
(
TestBatchedGemmGemmFP16
,
Test_FP16_OddK
)
{
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
...
@@ -140,9 +140,10 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
...
@@ -140,9 +140,10 @@ TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
// clang-format off
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
120
));
// Kernel can't support odd K because
K must be integer multiples of K1 values of either A or B
// Kernel can't support odd K
size
because
SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
129
,
128
));
// Kernel can't support odd O size because it must satisfy SizeO % B1SrcScalarPerVector == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
130
,
128
));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
129
));
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
128
,
128
,
128
,
129
));
// clang-format on
// clang-format on
}
}
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
View file @
b41e6019
...
@@ -19,6 +19,73 @@ TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
...
@@ -19,6 +19,73 @@ TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes);
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
136
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
136
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
40
,
128
,
1
},
{
128
,
128
,
136
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_PadO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
136
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddM
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
129
,
128
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddN
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
129
,
32
,
128
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
33
,
128
,
1
},
{
128
,
128
,
129
,
128
,
1
},
};
this
->
Run
();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
Test_FP16_OddO
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
128
,
128
,
32
,
129
,
1
},
};
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
DISABLED_Bench_FP16
)
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
DISABLED_Bench_FP16
)
{
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
...
@@ -37,3 +104,58 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
...
@@ -37,3 +104,58 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this
->
verify_
=
false
;
this
->
verify_
=
false
;
this
->
Run
();
this
->
Run
();
}
}
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
// TODO: enable KPadding tests when it is implemented
TEST
(
TestBatchedGemmSoftmaxGemmInterface
,
GemmSpecializationSizeMatch
)
{
int
P
=
120
;
// requires padding
int
Q
=
128
;
// do not require padding
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
// EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
TEST
(
TestBatchedGemmSoftmaxGemmInterface
,
GemmSpecializationSizeMismatch
)
{
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
128
,
128
,
120
,
128
));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
AdhocTest
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
49
,
49
,
64
,
64
,
24
},
{
64
,
49
,
64
,
64
,
24
},
{
1020
,
1020
,
64
,
128
,
24
},
{
576
,
576
,
64
,
64
,
24
},
};
this
->
bench_
=
true
;
this
->
Run
();
}
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
View file @
b41e6019
...
@@ -4,7 +4,10 @@
...
@@ -4,7 +4,10 @@
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
using
I
=
ck
::
Number
<
N
>
;
...
@@ -66,3 +69,121 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
...
@@ -66,3 +69,121 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
}
}
}
}
};
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
using
CLayout
=
Row
;
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
PassThrough
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// StrideC
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
0
,
// BatchStrideC
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
PassThrough
{},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
return
gemm
.
IsSupportedArgument
(
argument
);
}
};
test/softmax/test_softmax_util.hpp
View file @
b41e6019
...
@@ -9,7 +9,8 @@
...
@@ -9,7 +9,8 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/number.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor.hpp"
...
@@ -51,19 +52,23 @@ class TestSoftmax : public ::testing::Test
...
@@ -51,19 +52,23 @@ class TestSoftmax : public ::testing::Test
using
ReferenceInstance
=
using
ReferenceInstance
=
tensor_operation
::
host
::
ReferenceSoftmax
<
InDataType
,
OutDataType
,
AccDataType
>
;
tensor_operation
::
host
::
ReferenceSoftmax
<
InDataType
,
OutDataType
,
AccDataType
>
;
using
DeviceInstance
=
tensor_operation
::
device
::
DeviceSoftmax
<
InDataType
,
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
AccDataType
,
OutDataType
,
using
DeviceInstance
=
tensor_operation
::
device
::
DeviceSoftmaxImpl
<
InDataType
,
Rank
,
AccDataType
,
NumReduceDim
,
OutDataType
,
BlockSize
,
PassThrough
,
MThreadClusterSize
,
PassThrough
,
KThreadClusterSize
,
Rank
,
MThreadSliceSize
,
NumReduceDim
,
KThreadSliceSize
,
BlockSize
,
InSrcVectorDim
,
MThreadClusterSize
,
InSrcVectorSize
,
KThreadClusterSize
,
OutDstVectorSize
>
;
MThreadSliceSize
,
KThreadSliceSize
,
InSrcVectorDim
,
InSrcVectorSize
,
OutDstVectorSize
>
;
TestSoftmax
()
:
ref_instance_invoker_
(
ReferenceInstance
{}.
MakeInvoker
())
{}
TestSoftmax
()
:
ref_instance_invoker_
(
ReferenceInstance
{}.
MakeInvoker
())
{}
...
@@ -97,7 +102,9 @@ class TestSoftmax : public ::testing::Test
...
@@ -97,7 +102,9 @@ class TestSoftmax : public ::testing::Test
&
alpha
,
&
alpha
,
&
beta
,
&
beta
,
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
());
out_dev
.
GetDeviceBuffer
(),
PassThrough
{},
PassThrough
{});
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment