Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8166d875
Commit
8166d875
authored
Sep 14, 2022
by
rocking
Browse files
Modify test, instance and client example
parent
12673f3f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
51 additions
and
49 deletions
+51
-49
client_example/05_layernorm/layernorm2d.cpp
client_example/05_layernorm/layernorm2d.cpp
+2
-2
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
...tance/gpu/normalization/device_layernorm_f16_instance.cpp
+12
-12
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp
...tance/gpu/normalization/device_layernorm_f32_instance.cpp
+10
-10
test/layernorm/test_layernorm_fp16.cpp
test/layernorm/test_layernorm_fp16.cpp
+9
-9
test/layernorm/test_layernorm_fp32.cpp
test/layernorm/test_layernorm_fp32.cpp
+9
-9
test/layernorm/test_layernorm_util.hpp
test/layernorm/test_layernorm_util.hpp
+9
-7
No files found.
client_example/05_layernorm/layernorm2d.cpp
View file @
8166d875
...
...
@@ -81,8 +81,8 @@ int main(int argc, char* argv[])
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
M
,
N
},
// lengths
{
Stride
,
1
},
// xStrides
{
1
},
// gammaStrides
{
1
},
// betaStrides
{
0
,
1
},
// gammaStrides
{
0
,
1
},
// betaStrides
{
Stride
,
1
},
// yStrides
{
1
},
// reduceDims
1e-4
,
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f16_instance.cpp
View file @
8166d875
...
...
@@ -20,18 +20,18 @@ using Pass = ck::tensor_operation::element_wise::PassThrough;
template
<
index_t
Rank
,
index_t
Reduce
>
using
device_layernorm_f16_instances
=
std
::
tuple
<
// clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVector
Size
, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
2
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
8
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
8
,
8
,
8
>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVector
Dim, GammaSrcVectorSize, BetaSrcVectorDim
, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
// fallback kernel
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
,
DeviceLayernormImpl
<
F16
,
F16
,
F16
,
F32
,
F16
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
8
,
1
,
8
,
1
,
8
,
8
>
// clang-format on
>
;
...
...
library/src/tensor_operation_instance/gpu/normalization/device_layernorm_f32_instance.cpp
View file @
8166d875
...
...
@@ -20,16 +20,16 @@ template <index_t Rank, index_t Reduce>
using
device_layernorm_f32_instances
=
std
::
tuple
<
// clang-format off
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
2
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
4
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
4
,
4
,
4
>
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
1
,
1
,
1
,
1
,
1
,
1
>
,
// fallback kernel
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
2
,
1
,
2
,
1
,
2
,
2
>
,
// fallback kernel
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
8
,
32
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
4
,
64
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
2
,
128
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
8
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
16
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
,
DeviceLayernormImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
Pass
,
Rank
,
Reduce
,
256
,
1
,
256
,
1
,
32
,
1
,
4
,
1
,
4
,
1
,
4
,
4
>
// clang-format on
>
;
...
...
test/layernorm/test_layernorm_fp16.cpp
View file @
8166d875
...
...
@@ -14,15 +14,15 @@ class TestLayernormFP16 : public ck::TestLayernorm<Tuple>
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
GammaSrcVectorDim
, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestLayernormFP16
,
KernelTypes
);
...
...
test/layernorm/test_layernorm_fp32.cpp
View file @
8166d875
...
...
@@ -14,15 +14,15 @@ class TestLayernormFP32 : public ck::TestLayernorm<Tuple>
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
GammaSrcVectorDim
, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestLayernormFP32
,
KernelTypes
);
...
...
test/layernorm/test_layernorm_util.hpp
View file @
8166d875
...
...
@@ -48,9 +48,11 @@ class TestLayernorm : public ::testing::Test
static
constexpr
index_t
KThreadSliceSize
=
std
::
tuple_element_t
<
11
,
Tuple
>
{}.
value
;
static
constexpr
index_t
XYSrcVectorDim
=
std
::
tuple_element_t
<
12
,
Tuple
>
{}.
value
;
static
constexpr
index_t
XSrcVectorSize
=
std
::
tuple_element_t
<
13
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorSize
=
std
::
tuple_element_t
<
14
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorSize
=
std
::
tuple_element_t
<
15
,
Tuple
>
{}.
value
;
static
constexpr
index_t
YDstVectorSize
=
std
::
tuple_element_t
<
16
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorDim
=
std
::
tuple_element_t
<
14
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorSize
=
std
::
tuple_element_t
<
15
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorDim
=
std
::
tuple_element_t
<
16
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorSize
=
std
::
tuple_element_t
<
17
,
Tuple
>
{}.
value
;
static
constexpr
index_t
YDstVectorSize
=
std
::
tuple_element_t
<
18
,
Tuple
>
{}.
value
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -78,7 +80,9 @@ class TestLayernorm : public ::testing::Test
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorSize
>
;
...
...
@@ -115,10 +119,8 @@ class TestLayernorm : public ::testing::Test
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
(
lengths
,
std
::
vector
<
ck
::
index_t
>
{
x
.
mDesc
.
GetStrides
().
begin
(),
x
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
gamma
.
mDesc
.
GetStrides
().
begin
(),
gamma
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
beta
.
mDesc
.
GetStrides
().
begin
(),
beta
.
mDesc
.
GetStrides
().
end
()},
{
0
,
1
},
{
0
,
1
},
std
::
vector
<
ck
::
index_t
>
{
y
.
mDesc
.
GetStrides
().
begin
(),
y
.
mDesc
.
GetStrides
().
end
()},
reduceDims
,
1e-4
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment