Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8da05b38
You need to sign in or sign up before continuing.
Unverified
Commit
8da05b38
authored
Mar 05, 2023
by
zjing14
Committed by
GitHub
Mar 05, 2023
Browse files
Merge branch 'develop' into lwpck-586
parents
9a4fd1bc
e6cda9f8
Changes
151
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
684 additions
and
39 deletions
+684
-39
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
+182
-0
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
+380
-0
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
...m_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
+3
-3
test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp
...entwise_normalization/test_elementwise_layernorm_fp16.cpp
+1
-1
test/gemm_layernorm/CMakeLists.txt
test/gemm_layernorm/CMakeLists.txt
+7
-0
test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
.../gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
+77
-0
test/normalization/CMakeLists.txt
test/normalization/CMakeLists.txt
+6
-7
test/normalization/test_groupnorm_fp16.cpp
test/normalization/test_groupnorm_fp16.cpp
+7
-7
test/normalization/test_groupnorm_fp32.cpp
test/normalization/test_groupnorm_fp32.cpp
+7
-7
test/normalization/test_layernorm2d_fp16.cpp
test/normalization/test_layernorm2d_fp16.cpp
+7
-7
test/normalization/test_layernorm2d_fp32.cpp
test/normalization/test_layernorm2d_fp32.cpp
+7
-7
No files found.
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
0 → 100644
View file @
8da05b38
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
// Typed-test fixture for the FP16 cases. It declares no members of its own and
// inherits everything from TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>.
template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};
// Compile-time integer constants used as entries of the kernel type tuples below.
using I1_t = ck::Number<1>;
using I2_t = ck::Number<2>;

// MaskingSpecialization enumerators wrapped as types so they can be carried in a std::tuple.
using MaskDisabled_t =
    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
    ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
// clang-format off
// One std::tuple per fixture instantiation. The two entries are identical except for the
// last element, which selects the masking specialization.
// NOTE(review): element order presumably matches the fixture's tuple_element_t indices
// (dimension counts, data types, bias tuples, masking) -- confirm against the included util header.
using KernelTypes = ::testing::Types<
    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskDisabled_t>,
    std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, MaskOutUpperTriangle_t>
    >;
// clang-format on
// Instantiate the FP16 fixture once for every tuple listed in KernelTypes.
TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes);

// Runs the fixture with whatever problem sizes it holds by default.
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16)
{
    this->Run();
}
// Single problem whose first extent is 136 instead of 128.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM)
{
    const std::vector<std::vector<int>> problem_sizes{
        {136, 128, 32, 128, 2, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Single problem whose second extent is 136 instead of 128.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 136, 32, 128, 3, 2},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Two problems whose third extent (40 and 136) differs from the 32/128 used elsewhere.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 128, 40, 128, 2, 4},
        {128, 128, 136, 128, 4, 2},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Single problem whose fourth extent is 136 instead of 128.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 128, 32, 136, 1, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Single problem whose first extent is the odd value 129.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM)
{
    const std::vector<std::vector<int>> problem_sizes{
        {129, 128, 32, 128, 2, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Single problem whose second extent is the odd value 129.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 129, 32, 128, 4, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Two problems whose third extent is odd (33 and 129).
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 128, 33, 128, 2, 3},
        {128, 128, 129, 128, 2, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size.
// Single problem whose fourth extent is the odd value 129.
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
{
    const std::vector<std::vector<int>> problem_sizes{
        {128, 128, 32, 129, 2, 3},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Benchmark-only case (skipped by default because of the DISABLED_ prefix).
// Sets bench_ = true and verify_ = false before running the six sizes below.
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
{
    this->bench_  = true;
    this->verify_ = false;

    const std::vector<std::vector<int>> problem_sizes{
        {256, 256, 160, 160, 1, 16},
        {256, 64, 160, 64, 1, 16},
        {1024, 1024, 80, 80, 1, 16},
        {1024, 64, 80, 64, 1, 16},
        {4096, 4096, 40, 40, 1, 16},
        {4096, 64, 40, 64, 1, 16},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
// Benchmark-only case (skipped by default because of the DISABLED_ prefix).
// Sets bench_ = true and verify_ = false, then sweeps square sizes from 256 to 4096,
// each with the third/fourth extents at 64 and at 128.
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16)
{
    this->bench_  = true;
    this->verify_ = false;

    const std::vector<std::vector<int>> problem_sizes{
        {256, 256, 64, 64, 48, 16},
        {256, 256, 128, 128, 48, 16},
        {512, 512, 64, 64, 48, 16},
        {512, 512, 128, 128, 48, 16},
        {1024, 1024, 64, 64, 48, 16},
        {1024, 1024, 128, 128, 48, 16},
        {2048, 2048, 64, 64, 48, 16},
        {2048, 2048, 128, 128, 48, 16},
        {4096, 4096, 64, 64, 48, 16},
        {4096, 4096, 128, 128, 48, 16},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;

// For every GemmSpecialization, the wrapped FP16 instance must accept a problem in which
// exactly the dimensions named by the specialization are 120 and all others are 128.
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
{
    const int pad = 120; // requires padding
    const int fit = 128; // does not require padding

    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(fit, fit, fit, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(pad, fit, fit, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(fit, pad, fit, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(fit, fit, pad, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(pad, pad, fit, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(pad, fit, pad, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(fit, pad, pad, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(pad, pad, pad, fit));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(fit, fit, fit, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(pad, fit, fit, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(fit, pad, fit, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(fit, fit, pad, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(pad, pad, fit, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(pad, fit, pad, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(fit, pad, pad, pad));
    EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(pad, pad, pad, pad));
    // clang-format on
}
// Problem sizes the wrapped FP16 instance is expected to reject.
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch)
{
    // Local shorthands for the three specializations exercised below.
    using NoPadding   = DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>;
    using PadMNK      = DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>;
    using PadMNKO     = DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>;

    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_FALSE(NoPadding{}.IsSupported(128, 128, 120, 128));
    EXPECT_FALSE(PadMNK{}.IsSupported(128, 128, 128, 120));
    // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
    EXPECT_FALSE(PadMNKO{}.IsSupported(128, 128, 129, 128));
    EXPECT_FALSE(PadMNKO{}.IsSupported(128, 128, 130, 128));
    // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
    EXPECT_FALSE(PadMNKO{}.IsSupported(128, 128, 128, 129));
    // clang-format on
}
// Runs the fixture on a handful of assorted sizes (49, 64, 1020 and 576 for the first two extents).
// NOTE(review): entries are presumably {M, N, K, O, G0, G1} -- confirm against the fixture's Run().
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{
    const std::vector<std::vector<int>> problem_sizes{
        {49, 49, 64, 64, 4, 6},
        {64, 49, 64, 64, 4, 6},
        {1020, 1020, 64, 128, 4, 6},
        {576, 576, 64, 64, 4, 6},
    };
    this->lengths_ = problem_sizes;
    this->Run();
}
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
0 → 100644
View file @
8da05b38
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp"
// Pull the CK specialization enums into this header's scope.
using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MaskingSpecialization;
using ck::tensor_operation::device::TensorSpecialization;

// Shorthand for a compile-time index constant.
template <ck::index_t N>
using I = ck::Number<N>;

// Element types.
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

// GEMM tensor layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// GoogleTest fixture that drives profile_batched_gemm_bias_softmax_gemm_permute_impl
// over a list of problem sizes.
//
// Tuple is a std::tuple whose elements are, in order: the five dimension-count types
// (G, M, N, K, O; each exposing ::value), the A/B0/B1/C data types, the Acc0 and Acc1
// bias data types, and the masking type (exposing ::value).
template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
    using NumDimGType      = std::tuple_element_t<0, Tuple>;
    using NumDimMType      = std::tuple_element_t<1, Tuple>;
    using NumDimNType      = std::tuple_element_t<2, Tuple>;
    using NumDimKType      = std::tuple_element_t<3, Tuple>;
    using NumDimOType      = std::tuple_element_t<4, Tuple>;
    using ADataType        = std::tuple_element_t<5, Tuple>;
    using B0DataType       = std::tuple_element_t<6, Tuple>;
    using B1DataType       = std::tuple_element_t<7, Tuple>;
    using CDataType        = std::tuple_element_t<8, Tuple>;
    using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
    using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
    using MaskingType      = std::tuple_element_t<11, Tuple>;

    // Problem sizes to run; each entry is {M, N, K, O, G0, G1}. Tests may overwrite this.
    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 6, 4},
        {256, 256, 128, 128, 4, 6},
        {512, 512, 64, 64, 3, 2},
        {512, 512, 128, 128, 2, 3},
        {1024, 1024, 64, 64, 3, 1},
        {1024, 1024, 128, 128, 1, 1},
    };
    bool bench_  = false; // forwarded to the profiler as its fourth argument
    bool verify_ = true;  // forwarded to the profiler as its first argument

    // Profiles one problem size and records a non-fatal failure if the profiler returns false.
    void RunSingle(int M, int N, int K, int O, int G0, int G1)
    {
        // NOTE(review): the literal arguments 2 and false are presumably the profiler's
        // init-method and logging switches -- confirm against the profiler signature.
        bool pass =
            ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl<NumDimGType::value,
                                                                              NumDimMType::value,
                                                                              NumDimNType::value,
                                                                              NumDimKType::value,
                                                                              NumDimOType::value,
                                                                              ADataType,
                                                                              B0DataType,
                                                                              B1DataType,
                                                                              CDataType,
                                                                              Acc0BiasDataType,
                                                                              Acc1BiasDataType,
                                                                              MaskingType::value>(
                verify_, 2, false, bench_, M, N, K, O, G0, G1);

        EXPECT_TRUE(pass);
    }

    // Calls RunSingle once per entry of lengths_.
    void Run()
    {
        // Iterate by const reference: the entries are only read, so copying each inner
        // vector per iteration is unnecessary.
        for(const auto& lengths : this->lengths_)
        {
            // lengths_ is overwritten by individual tests; fail loudly on a malformed entry
            // instead of indexing out of bounds below.
            ASSERT_EQ(lengths.size(), 6u) << "each lengths_ entry must be {M, N, K, O, G0, G1}";

            int M  = lengths[0];
            int N  = lengths[1];
            int K  = lengths[2];
            int O  = lengths[3];
            int G0 = lengths[4];
            int G1 = lengths[5];

            this->RunSingle(M, N, K, O, G0, G1);
        }
    }
};
// Wraps one fixed FP16 instantiation of DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle,
// parameterized only by GemmSpec, so tests can ask whether a given (M, N, K, O) problem is
// accepted by IsSupportedArgument. The instance always uses
// MaskingSpecialization::MaskOutUpperTriangle and one F16 Acc0 bias tensor.
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = F16;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = ScaleAdd;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
            2,
            1,
            1,
            1,
            1,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            ck::Tuple<F16>,
            ck::Tuple<>,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            TensorSpecialization::Default, // ATensorSpec
            TensorSpecialization::Default, // B0TensorSpec
            TensorSpecialization::Default, // B1TensorSpec
            TensorSpecialization::Default, // CTensorSpec
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
            MaskingSpecialization::MaskOutUpperTriangle>;

    // Returns the result of IsSupportedArgument for a problem of size (M, N, K, O) with
    // G0 = G1 = 1. Only lengths and strides are meaningful; every data pointer handed to
    // MakeArgument is null.
    bool IsSupported(int M, int N, int K, int O)
    {
        const int G0 = 1, G1 = 1;

        // A layout [G0, M, G1, K]
        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

        // B0 layout [G0, N, G1, K]
        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

        // B1 layout [G0, N, G1, O]
        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

        // C layout [G0, M, G1, O]
        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

        // D layout [G0, M, G1, N]
        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
        std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

        // No invoker is created: nothing is launched, only the argument is validated.
        auto gemm     = DeviceGemmGemmInstance{};
        auto argument = gemm.MakeArgument(
            static_cast<ADataType*>(nullptr),
            static_cast<B0DataType*>(nullptr),
            static_cast<B1DataType*>(nullptr),
            static_cast<CDataType*>(nullptr),
            std::array<void*, 1>{nullptr}, // p_acc0_biases
            {},                            // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                 // acc1_biases_gs_ms_os_lengths
            {},                 // acc1_biases_gs_ms_os_strides
            PassThrough{},      // a_element_op
            PassThrough{},      // b0_element_op
            Acc0ElementOp{1.f}, // acc0_element_op
            PassThrough{},      // b1_element_op
            PassThrough{});     // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};
// Wraps one fixed BF16 instantiation of DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle,
// parameterized only by GemmSpec, so tests can ask whether a given (M, N, K, O) problem is
// accepted by IsSupportedArgument. The instance always uses
// MaskingSpecialization::MaskOutUpperTriangle and one BF16 Acc0 bias tensor.
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    using ADataType        = BF16;
    using B0DataType       = BF16;
    using B1DataType       = BF16;
    using AccDataType      = float;
    using CShuffleDataType = BF16;
    using CDataType        = BF16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = ScaleAdd;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
            2,
            1,
            1,
            1,
            1,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            ck::Tuple<BF16>,
            ck::Tuple<>,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            TensorSpecialization::Default, // ATensorSpec
            TensorSpecialization::Default, // B0TensorSpec
            TensorSpecialization::Default, // B1TensorSpec
            TensorSpecialization::Default, // CTensorSpec
            1,
            256,
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
            MaskingSpecialization::MaskOutUpperTriangle>;

    // Returns the result of IsSupportedArgument for a problem of size (M, N, K, O) with
    // G0 = G1 = 1. Only lengths and strides are meaningful; every data pointer handed to
    // MakeArgument is null.
    bool IsSupported(int M, int N, int K, int O)
    {
        const int G0 = 1, G1 = 1;

        // A layout [G0, M, G1, K]
        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
        std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

        // B0 layout [G0, N, G1, K]
        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
        std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

        // B1 layout [G0, N, G1, O]
        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
        std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

        // C layout [G0, M, G1, O]
        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
        std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

        // D layout [G0, M, G1, N]
        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
        std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};

        // No invoker is created: nothing is launched, only the argument is validated.
        auto gemm     = DeviceGemmGemmInstance{};
        auto argument = gemm.MakeArgument(
            static_cast<ADataType*>(nullptr),
            static_cast<B0DataType*>(nullptr),
            static_cast<B1DataType*>(nullptr),
            static_cast<CDataType*>(nullptr),
            std::array<void*, 1>{nullptr}, // p_acc0_biases
            {},                            // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                 // acc1_biases_gs_ms_os_lengths
            {},                 // acc1_biases_gs_ms_os_strides
            PassThrough{},      // a_element_op
            PassThrough{},      // b0_element_op
            Acc0ElementOp{1.f}, // acc0_element_op
            PassThrough{},      // b1_element_op
            PassThrough{});     // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
View file @
8da05b38
...
...
@@ -27,7 +27,7 @@ using KernelTypes = ::testing::Types<
TYPED_TEST_SUITE
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
KernelTypes
);
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
DISABLED_
Test_BF16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
Test_BF16
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
Test_BF16_PadM
)
{
...
...
@@ -96,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO)
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
DISABLED_
Bench_BF16_IrregularK
)
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
Bench_BF16_IrregularK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{{
256
,
256
,
160
,
160
,
1
,
16
},
{
256
,
64
,
160
,
64
,
1
,
16
},
...
...
@@ -109,7 +109,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF1
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
DISABLED_
Bench_BF16
)
TYPED_TEST
(
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16
,
Bench_BF16
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{
{
256
,
256
,
64
,
64
,
48
,
16
},
...
...
test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp
View file @
8da05b38
...
...
@@ -23,7 +23,7 @@ class TestElementwiseLayernorm : public ::testing::Test
{
// M, N
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
lengths
=
{
{
1
,
1
},
{
25
,
16
},
{
39
,
777
},
{
100
,
200
},
{
1024
,
1024
},
{
48
*
256
,
2048
}};
{
1
,
1
},
{
25
,
16
},
{
39
,
777
},
{
100
,
200
},
{
1024
,
1024
},
{
48
*
256
,
2048
}
,
{
4096
,
8192
}
};
for
(
auto
length
:
lengths
)
{
...
...
test/gemm_layernorm/CMakeLists.txt
0 → 100644
View file @
8da05b38
# Aggregate target; the executable below is attached to it as a dependency.
add_custom_target(test_gemm_layernorm)

add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16
    test_gemm_add_relu_add_layernorm_fp16.cpp)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16
    PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp
0 → 100644
View file @
8da05b38
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp"
// GEMM tensor layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element types.
using F16 = ck::half_t;
using F32 = float;

using ck::index_t;
// GoogleTest fixture that runs profile_gemm_add_relu_add_layernorm_impl over a fixed list
// of {M, N, K} problem sizes.
//
// Tuple is a std::tuple holding, in order: nine data types (A, B, Acc, D0, D1, EMeanVar,
// Gamma, Beta, H) followed by five layout types (A, B, D0, D1, H).
template <typename Tuple>
class TestGemmAddReluAddLayernorm : public ::testing::Test
{
    protected:
    using ADataType        = std::tuple_element_t<0, Tuple>;
    using BDataType        = std::tuple_element_t<1, Tuple>;
    using AccDataType      = std::tuple_element_t<2, Tuple>;
    using D0DataType       = std::tuple_element_t<3, Tuple>;
    using D1DataType       = std::tuple_element_t<4, Tuple>;
    using EMeanVarDataType = std::tuple_element_t<5, Tuple>;
    using GammaDataType    = std::tuple_element_t<6, Tuple>;
    using BetaDataType     = std::tuple_element_t<7, Tuple>;
    using HDataType        = std::tuple_element_t<8, Tuple>;
    using ALayout          = std::tuple_element_t<9, Tuple>;
    using BLayout          = std::tuple_element_t<10, Tuple>;
    using D0Layout         = std::tuple_element_t<11, Tuple>;
    using D1Layout         = std::tuple_element_t<12, Tuple>;
    using HLayout          = std::tuple_element_t<13, Tuple>;

    // Profiles every problem size and records a non-fatal failure for each one the
    // profiler reports as unsuccessful.
    void Run()
    {
        // Each entry is {M, N, K}. The table is never modified, so it is const.
        const std::vector<std::vector<ck::index_t>> lengths = {
            {1024, 1024, 1024}, {2048, 640, 640}, {1, 1, 1}};

        // Iterate by const reference: copying each inner vector per iteration is unnecessary.
        for(const auto& length : lengths)
        {
            int M = length[0];
            int N = length[1];
            int K = length[2];

            // For each tensor, a Row layout selects its second extent as the stride and
            // any other layout selects its first extent.
            int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
            int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
            // D0Layout is not consulted; the D0 stride is fixed at 0.
            // NOTE(review): presumably D0 is broadcast along M -- confirm against the profiler.
            int StrideD0 = 0;
            int StrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
            int StrideH  = ck::is_same_v<HLayout, Row> ? N : M;

            // NOTE(review): the leading literals (true, 1, false, false) are presumably the
            // profiler's verification, init-method, logging and timing switches -- confirm
            // against the profiler signature.
            bool success = ck::profiler::profile_gemm_add_relu_add_layernorm_impl<ADataType,
                                                                                  BDataType,
                                                                                  AccDataType,
                                                                                  D0DataType,
                                                                                  D1DataType,
                                                                                  EMeanVarDataType,
                                                                                  GammaDataType,
                                                                                  BetaDataType,
                                                                                  HDataType,
                                                                                  ALayout,
                                                                                  BLayout,
                                                                                  D0Layout,
                                                                                  D1Layout,
                                                                                  HLayout>(
                true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideH);

            EXPECT_TRUE(success);
        }
    }
};
// clang-format off
// Fixture instantiations: identical data types throughout, covering the four
// Row/Col combinations of the A and B layouts (D0, D1 and H stay Row).
using KernelTypes = ::testing::Types<
    //         A    B    Acc  D0   D1   EMeanVar Gamma Beta H    ALayout BLayout D0Layout D1Layout HLayout
    std::tuple<F16, F16, F32, F16, F16, F16,     F16,  F16, F16, Row,    Row,    Row,     Row,     Row>,
    std::tuple<F16, F16, F32, F16, F16, F16,     F16,  F16, F16, Row,    Col,    Row,     Row,     Row>,
    std::tuple<F16, F16, F32, F16, F16, F16,     F16,  F16, F16, Col,    Row,    Row,     Row,     Row>,
    std::tuple<F16, F16, F32, F16, F16, F16,     F16,  F16, F16, Col,    Col,    Row,     Row,     Row>>;
// clang-format on
// Instantiate the fixture once for every tuple listed in KernelTypes.
TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);

// Runs the fixture's built-in list of problem sizes.
TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16)
{
    this->Run();
}
test/normalization/CMakeLists.txt
View file @
8da05b38
add_custom_target
(
test_
layernorm
)
add_custom_target
(
test_
normalization
)
add_gtest_executable
(
test_layernorm2d_fp32 test_layernorm2d_fp32.cpp
)
add_gtest_executable
(
test_layernorm2d_fp16 test_layernorm2d_fp16.cpp
)
add_gtest_executable
(
test_groupnorm_fp16 test_groupnorm_fp16.cpp
)
add_gtest_executable
(
test_groupnorm_fp32 test_groupnorm_fp32.cpp
)
target_link_libraries
(
test_layernorm2d_fp32 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_layernorm2d_fp16 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_groupnorm_fp16 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_groupnorm_fp32 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_
layernorm
test_layernorm2d_fp32
)
add_dependencies
(
test_
layernorm
test_layernorm2d_fp16
)
add_dependencies
(
test_
layernorm
test_groupnorm_fp16
)
add_dependencies
(
test_
layernorm
test_groupnorm_fp32
)
add_dependencies
(
test_
normalization
test_layernorm2d_fp32
)
add_dependencies
(
test_
normalization
test_layernorm2d_fp16
)
add_dependencies
(
test_
normalization
test_groupnorm_fp16
)
add_dependencies
(
test_
normalization
test_groupnorm_fp32
)
test/normalization/test_groupnorm_fp16.cpp
View file @
8da05b38
...
...
@@ -15,7 +15,7 @@ class TestGroupnorm : public ::testing::Test
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
Acc
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
Compute
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
...
...
@@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test
ck
::
profiler
::
profile_groupnorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
}
...
...
@@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType,
Acc
DataType, YDataType>
// XDataType, GammaDataType, BetaDataType,
Compute
DataType, YDataType>
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>>
;
TYPED_TEST_SUITE
(
TestGroupnorm
,
KernelTypes
);
...
...
test/normalization/test_groupnorm_fp32.cpp
View file @
8da05b38
...
...
@@ -15,7 +15,7 @@ class TestGroupnorm : public ::testing::Test
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
Acc
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
Compute
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
...
...
@@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test
ck
::
profiler
::
profile_groupnorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
}
...
...
@@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType,
Acc
DataType, YDataType>
// XDataType, GammaDataType, BetaDataType,
Compute
DataType, YDataType>
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>>
;
TYPED_TEST_SUITE
(
TestGroupnorm
,
KernelTypes
);
...
...
test/normalization/test_layernorm2d_fp16.cpp
View file @
8da05b38
...
...
@@ -15,7 +15,7 @@ class TestLayernorm2d : public ::testing::Test
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
Acc
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
Compute
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
...
...
@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
bool
success
=
ck
::
profiler
::
profile_layernorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
2
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
...
...
@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType,
Acc
DataType, YDataType>
// XDataType, GammaDataType, BetaDataType,
Compute
DataType, YDataType>
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>>
;
TYPED_TEST_SUITE
(
TestLayernorm2d
,
KernelTypes
);
...
...
test/normalization/test_layernorm2d_fp32.cpp
View file @
8da05b38
...
...
@@ -15,7 +15,7 @@ class TestLayernorm2d : public ::testing::Test
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
Acc
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
Compute
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
...
...
@@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test
bool
success
=
ck
::
profiler
::
profile_layernorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
Acc
DataType
,
Compute
DataType
,
YDataType
,
2
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
...
...
@@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType,
Acc
DataType, YDataType>
// XDataType, GammaDataType, BetaDataType,
Compute
DataType, YDataType>
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>>
;
TYPED_TEST_SUITE
(
TestLayernorm2d
,
KernelTypes
);
...
...
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment