Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5aa3c344
Unverified
Commit
5aa3c344
authored
Oct 05, 2022
by
rocking5566
Committed by
GitHub
Oct 05, 2022
Browse files
Merge branch 'develop' into gemm_layernorm_welford
parents
7fefc966
9d8f834a
Changes
129
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
405 additions
and
74 deletions
+405
-74
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
..._batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
+193
-0
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
+21
-8
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
...gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
+15
-9
test/layernorm/CMakeLists.txt
test/layernorm/CMakeLists.txt
+13
-6
test/layernorm/test_groupnorm_fp16.cpp
test/layernorm/test_groupnorm_fp16.cpp
+56
-0
test/layernorm/test_groupnorm_fp32.cpp
test/layernorm/test_groupnorm_fp32.cpp
+56
-0
test/layernorm/test_layernorm2d_fp16.cpp
test/layernorm/test_layernorm2d_fp16.cpp
+13
-13
test/layernorm/test_layernorm2d_fp32.cpp
test/layernorm/test_layernorm2d_fp32.cpp
+13
-13
test/layernorm/test_layernorm2d_util.hpp
test/layernorm/test_layernorm2d_util.hpp
+25
-25
No files found.
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
using
F16
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
typename
Tuple
>
struct
TestBatchedGemmMaskingScaleSoftmaxGemmPermute
:
public
::
testing
::
Test
{
using
ADataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
B0DataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
B1DataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
B0Layout
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CPermuteNumDims_G_M_O
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
{
256
,
256
,
64
,
64
,
6
,
4
},
{
256
,
256
,
128
,
128
,
4
,
6
},
{
512
,
512
,
64
,
64
,
3
,
2
},
{
512
,
512
,
128
,
128
,
2
,
3
},
{
1024
,
1024
,
64
,
64
,
3
,
1
},
{
1024
,
1024
,
128
,
128
,
1
,
1
},
};
bool
bench_
=
false
;
bool
verify_
=
true
;
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
G0
,
int
G1
)
{
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_masking_scale_softmax_gemm_permute_impl
<
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
EXPECT_TRUE
(
pass
);
}
void
Run
()
{
for
(
auto
lengths
:
this
->
lengths_
)
{
int
M
=
lengths
[
0
];
int
N
=
lengths
[
1
];
int
K
=
lengths
[
2
];
int
O
=
lengths
[
3
];
int
G0
=
lengths
[
4
];
int
G1
=
lengths
[
5
];
this
->
RunSingle
(
M
,
N
,
K
,
O
,
G0
,
G1
);
}
}
};
template
<
GemmSpecialization
GemmSpec
>
struct
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
ALayout
=
Row
;
using
B0Layout
=
Col
;
using
B1Layout
=
Row
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
CPermuteNumDims_G_M_O
=
S
<
2
,
1
,
1
>
;
// "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using
ADataType
=
F16
;
using
B0DataType
=
F16
;
using
B1DataType
=
F16
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
F16
;
using
CDataType
=
F16
;
using
AElementOp
=
PassThrough
;
using
B0ElementOp
=
PassThrough
;
using
Acc0ElementOp
=
Scale
;
using
B1ElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using
DeviceGemmGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
GemmSpec
,
1
,
256
,
128
,
// MPerBlock
128
,
// NPerBlock
32
,
// KPerBlock
128
,
// Gemm1NPerBlock
32
,
// Gemm1KPerBlock
8
,
// AK1
8
,
// BK1
2
,
// B1K1
32
,
// MPerXDL
32
,
// NPerXDL
1
,
// MXdlPerWave
4
,
// NXdlPerWave
4
,
// Gemm1NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
// BBlockTransfer
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
// B1BlockTransfer
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
true
>
;
// Masking
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
auto
gemm
=
DeviceGemmGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
nullptr
),
static_cast
<
B0DataType
*>
(
nullptr
),
static_cast
<
B1DataType
*>
(
nullptr
),
static_cast
<
CDataType
*>
(
nullptr
),
M
,
N
,
K
,
O
,
0
,
// BatchCount
{
0
,
0
,
M
,
O
},
// gs ms ns lengths
{
0
,
O
,
0
,
1
},
// gs ms ns strides
0
,
// StrideA
0
,
// StrideB0
0
,
// StrideB1
0
,
// BatchStrideA
0
,
// BatchStrideB0
0
,
// BatchStrideB1
PassThrough
{},
// a_element_op
PassThrough
{},
// b0_element_op
Scale
{
1.
f
},
// acc0_element_op
PassThrough
{},
// b1_element_op
PassThrough
{});
// c_element_op
return
gemm
.
IsSupportedArgument
(
argument
);
}
};
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp
View file @
5aa3c344
...
...
@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmSoftmaxGemmFP16
,
DISABLED_Bench_FP16_IrregularK
)
{
this
->
lengths_
=
std
::
vector
<
std
::
vector
<
int
>>
{{
256
,
256
,
160
,
160
,
16
},
{
256
,
64
,
160
,
64
,
16
},
{
1024
,
1024
,
80
,
80
,
16
},
{
1024
,
64
,
80
,
64
,
16
},
{
4096
,
4096
,
40
,
40
,
16
},
{
4096
,
64
,
40
,
64
,
16
}};
this
->
bench_
=
true
;
this
->
verify_
=
false
;
this
->
Run
();
}
using
ck
::
tensor_operation
::
device
::
GemmSpecialization
;
// TODO: enable KPadding tests when it is implemented
...
...
@@ -118,19 +131,19 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
Default
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
Q
));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
Q
));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
Q
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
OPadding
>
{}.
IsSupported
(
Q
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MOPadding
>
{}.
IsSupported
(
P
,
Q
,
Q
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NOPadding
>
{}.
IsSupported
(
Q
,
P
,
Q
,
P
));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
KOPadding
>
{}.
IsSupported
(
Q
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNOPadding
>
{}.
IsSupported
(
P
,
P
,
Q
,
P
));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
//
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MKOPadding
>
{}.
IsSupported
(
P
,
Q
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
NKOPadding
>
{}.
IsSupported
(
Q
,
P
,
P
,
P
));
EXPECT_TRUE
(
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
<
GemmSpecialization
::
MNKOPadding
>
{}.
IsSupported
(
P
,
P
,
P
,
P
));
// clang-format on
}
...
...
test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp
View file @
5aa3c344
...
...
@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
using
B1Layout
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
7
,
Tuple
>
;
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{
{
256
,
256
,
64
,
64
,
4
},
{
256
,
256
,
128
,
128
,
4
},
{
512
,
512
,
64
,
64
,
2
},
{
512
,
512
,
128
,
128
,
2
},
{
1024
,
1024
,
64
,
64
,
1
},
{
1024
,
1024
,
128
,
128
,
1
},
};
std
::
vector
<
std
::
vector
<
int
>>
lengths_
=
{{
256
,
256
,
64
,
64
,
4
},
{
256
,
256
,
128
,
128
,
4
},
{
512
,
512
,
64
,
64
,
2
},
{
512
,
512
,
128
,
128
,
2
},
{
1024
,
1024
,
64
,
64
,
1
},
{
1024
,
1024
,
128
,
128
,
1
},
{
256
,
256
,
160
,
160
,
4
},
{
256
,
64
,
160
,
64
,
4
},
{
1024
,
1024
,
80
,
80
,
2
},
{
1024
,
64
,
80
,
64
,
2
},
{
4096
,
4096
,
40
,
40
,
1
},
{
4096
,
64
,
40
,
64
,
1
}};
bool
bench_
=
false
;
bool
verify_
=
true
;
...
...
@@ -155,7 +160,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
false
>
;
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
...
...
test/layernorm/CMakeLists.txt
View file @
5aa3c344
add_custom_target
(
test_layernorm
)
add_gtest_executable
(
test_layernorm_fp32 test_layernorm_fp32.cpp
)
add_gtest_executable
(
test_layernorm_fp16 test_layernorm_fp16.cpp
)
add_gtest_executable
(
test_layernorm2d_fp32 test_layernorm2d_fp32.cpp
)
add_gtest_executable
(
test_layernorm2d_fp16 test_layernorm2d_fp16.cpp
)
add_gtest_executable
(
test_groupnorm_fp16 test_groupnorm_fp16.cpp
)
add_gtest_executable
(
test_groupnorm_fp32 test_groupnorm_fp32.cpp
)
target_link_libraries
(
test_layernorm_fp32 PRIVATE utility
)
target_link_libraries
(
test_layernorm_fp16 PRIVATE utility
)
target_link_libraries
(
test_layernorm2d_fp32 PRIVATE utility
)
target_link_libraries
(
test_layernorm2d_fp16 PRIVATE utility
)
target_link_libraries
(
test_groupnorm_fp16 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_groupnorm_fp32 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_layernorm test_layernorm2d_fp32
)
add_dependencies
(
test_layernorm test_layernorm2d_fp16
)
add_dependencies
(
test_layernorm test_groupnorm_fp16
)
add_dependencies
(
test_layernorm test_groupnorm_fp32
)
add_dependencies
(
test_layernorm test_layernorm_fp32
)
add_dependencies
(
test_layernorm test_layernorm_fp16
)
test/layernorm/test_groupnorm_fp16.cpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/include/profile_groupnorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ck
::
index_t
;
template
<
typename
Tuple
>
class
TestGroupnorm
:
public
::
testing
::
Test
{
protected:
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
{
// N, H, W, G, C
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
lengths
=
{{
1
,
1
,
1
,
1
,
1
},
{
1
,
2
,
3
,
4
,
5
},
{
256
,
9
,
9
,
9
,
9
},
{
1
,
64
,
64
,
32
,
10
},
{
1
,
32
,
32
,
32
,
20
},
{
1
,
16
,
16
,
32
,
40
}};
for
(
auto
length
:
lengths
)
{
bool
success
=
ck
::
profiler
::
profile_groupnorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
YDataType
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
}
}
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>
,
std
::
tuple
<
F16
,
F16
,
F16
,
F32
,
F16
>>
;
TYPED_TEST_SUITE
(
TestGroupnorm
,
KernelTypes
);
TYPED_TEST
(
TestGroupnorm
,
Test_FP16
)
{
this
->
Run
();
}
test/layernorm/test_groupnorm_fp32.cpp
0 → 100644
View file @
5aa3c344
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/include/profile_groupnorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ck
::
index_t
;
template
<
typename
Tuple
>
class
TestGroupnorm
:
public
::
testing
::
Test
{
protected:
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
GammaDataType
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BetaDataType
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
YDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
void
Run
()
{
// N, H, W, G, C
std
::
vector
<
std
::
vector
<
ck
::
index_t
>>
lengths
=
{{
1
,
1
,
1
,
1
,
1
},
{
1
,
2
,
3
,
4
,
5
},
{
256
,
9
,
9
,
9
,
9
},
{
1
,
64
,
64
,
32
,
10
},
{
1
,
32
,
32
,
32
,
20
},
{
1
,
16
,
16
,
32
,
40
}};
for
(
auto
length
:
lengths
)
{
bool
success
=
ck
::
profiler
::
profile_groupnorm_impl
<
XDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
YDataType
>
(
true
,
2
,
false
,
false
,
length
);
EXPECT_TRUE
(
success
);
}
}
};
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType>
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
F32
,
F32
>>
;
TYPED_TEST_SUITE
(
TestGroupnorm
,
KernelTypes
);
TYPED_TEST
(
TestGroupnorm
,
Test_FP32
)
{
this
->
Run
();
}
test/layernorm/test_layernorm_fp16.cpp
→
test/layernorm/test_layernorm
2d
_fp16.cpp
View file @
5aa3c344
...
...
@@ -2,28 +2,28 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"
#include "test_layernorm
2d
_util.hpp"
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
template
<
typename
Tuple
>
class
TestLayernormFP16
:
public
ck
::
TestLayernorm
<
Tuple
>
class
TestLayernorm
2d
FP16
:
public
ck
::
TestLayernorm
2d
<
Tuple
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>
,
I
<
8
>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
GammaSrcVectorDim
, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
half_t
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
8
>
,
I
<
8
>>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestLayernormFP16
,
KernelTypes
);
TYPED_TEST
(
TestLayernormFP16
,
Test_FP16
)
{
this
->
Run
();
}
TYPED_TEST_SUITE
(
TestLayernorm
2d
FP16
,
KernelTypes
);
TYPED_TEST
(
TestLayernorm
2d
FP16
,
Test_FP16
)
{
this
->
Run
();
}
test/layernorm/test_layernorm_fp32.cpp
→
test/layernorm/test_layernorm
2d
_fp32.cpp
View file @
5aa3c344
...
...
@@ -2,28 +2,28 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_layernorm_util.hpp"
#include "test_layernorm
2d
_util.hpp"
template
<
ck
::
index_t
N
>
using
I
=
ck
::
Number
<
N
>
;
template
<
typename
Tuple
>
class
TestLayernormFP32
:
public
ck
::
TestLayernorm
<
Tuple
>
class
TestLayernorm
2d
FP32
:
public
ck
::
TestLayernorm
2d
<
Tuple
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, , GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>
,
I
<
4
>>
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
GammaSrcVectorDim
, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, YDstVectorSize>
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
8
>
,
I
<
32
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
4
>
,
I
<
64
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
128
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
,
std
::
tuple
<
float
,
float
,
float
,
float
,
float
,
I
<
2
>
,
I
<
1
>
,
I
<
256
>
,
I
<
1
>
,
I
<
256
>
,
I
<
2
>
,
I
<
8
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
1
>
,
I
<
4
>
,
I
<
4
>>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestLayernormFP32
,
KernelTypes
);
TYPED_TEST
(
TestLayernormFP32
,
Test_FP32
)
{
this
->
Run
();
}
TYPED_TEST_SUITE
(
TestLayernorm
2d
FP32
,
KernelTypes
);
TYPED_TEST
(
TestLayernorm
2d
FP32
,
Test_FP32
)
{
this
->
Run
();
}
test/layernorm/test_layernorm_util.hpp
→
test/layernorm/test_layernorm
2d
_util.hpp
View file @
5aa3c344
...
...
@@ -31,7 +31,7 @@ std::string serialize_range(const Range& range)
}
template
<
typename
Tuple
>
class
TestLayernorm
:
public
::
testing
::
Test
class
TestLayernorm
2d
:
public
::
testing
::
Test
{
protected:
using
XDataType
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
...
...
@@ -48,9 +48,11 @@ class TestLayernorm : public ::testing::Test
static
constexpr
index_t
KThreadSliceSize
=
std
::
tuple_element_t
<
11
,
Tuple
>
{}.
value
;
static
constexpr
index_t
XYSrcVectorDim
=
std
::
tuple_element_t
<
12
,
Tuple
>
{}.
value
;
static
constexpr
index_t
XSrcVectorSize
=
std
::
tuple_element_t
<
13
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorSize
=
std
::
tuple_element_t
<
14
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorSize
=
std
::
tuple_element_t
<
15
,
Tuple
>
{}.
value
;
static
constexpr
index_t
YDstVectorSize
=
std
::
tuple_element_t
<
16
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorDim
=
std
::
tuple_element_t
<
14
,
Tuple
>
{}.
value
;
static
constexpr
index_t
GammaSrcVectorSize
=
std
::
tuple_element_t
<
15
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorDim
=
std
::
tuple_element_t
<
16
,
Tuple
>
{}.
value
;
static
constexpr
index_t
BetaSrcVectorSize
=
std
::
tuple_element_t
<
17
,
Tuple
>
{}.
value
;
static
constexpr
index_t
YDstVectorSize
=
std
::
tuple_element_t
<
18
,
Tuple
>
{}.
value
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
...
@@ -78,23 +80,24 @@ class TestLayernorm : public ::testing::Test
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorSize
>
;
TestLayernorm
()
:
ref_instance_invoker_
(
ReferenceInstance
{}.
MakeInvoker
())
{}
TestLayernorm
2d
()
:
ref_instance_invoker_
(
ReferenceInstance
{}.
MakeInvoker
())
{}
void
RunSingle
(
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
reduceDims
)
void
RunSingle
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
reduceDims
,
const
std
::
vector
<
index_t
>&
GammaLength
,
const
std
::
vector
<
index_t
>&
GammaStride
,
const
std
::
vector
<
index_t
>&
BetaLength
,
const
std
::
vector
<
index_t
>&
BetaStride
)
{
std
::
vector
<
index_t
>
reduceLength
(
reduceDims
.
size
());
for
(
int
i
=
0
;
i
<
NumReduceDim
;
++
i
)
{
reduceLength
[
i
]
=
lengths
[
reduceDims
[
i
]];
}
Tensor
<
XDataType
>
x
(
lengths
);
Tensor
<
GammaDataType
>
gamma
(
reduce
Length
);
Tensor
<
BetaDataType
>
beta
(
reduce
Length
);
Tensor
<
GammaDataType
>
gamma
(
Gamma
Length
);
Tensor
<
BetaDataType
>
beta
(
Beta
Length
);
Tensor
<
YDataType
>
y
(
lengths
);
Tensor
<
YDataType
>
y_ref
(
lengths
);
...
...
@@ -115,10 +118,8 @@ class TestLayernorm : public ::testing::Test
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
(
lengths
,
std
::
vector
<
ck
::
index_t
>
{
x
.
mDesc
.
GetStrides
().
begin
(),
x
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
gamma
.
mDesc
.
GetStrides
().
begin
(),
gamma
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
beta
.
mDesc
.
GetStrides
().
begin
(),
beta
.
mDesc
.
GetStrides
().
end
()},
GammaStride
,
BetaStride
,
std
::
vector
<
ck
::
index_t
>
{
y
.
mDesc
.
GetStrides
().
begin
(),
y
.
mDesc
.
GetStrides
().
end
()},
reduceDims
,
1e-4
,
...
...
@@ -163,17 +164,16 @@ class TestLayernorm : public ::testing::Test
void
Run
()
{
for
(
auto
length
:
this
->
lengths_
)
std
::
vector
<
std
::
vector
<
index_t
>>
lengths
=
{
{
4
,
256
},
{
8
,
511
},
{
9
,
1032
},
{
4
,
2048
},
{
1
,
8192
},
{
4000
,
2000
}};
for
(
auto
length
:
lengths
)
{
this
->
RunSingle
(
length
,
reduceDims_
[
0
]
);
this
->
RunSingle
(
length
,
{
1
},
{
length
[
1
]},
{
0
,
1
},
{
length
[
1
]},
{
0
,
1
}
);
}
}
std
::
vector
<
std
::
vector
<
index_t
>>
lengths_
=
{
{
4
,
256
},
{
8
,
511
},
{
9
,
1032
},
{
4
,
2048
},
{
1
,
8192
},
{
4000
,
2000
}};
std
::
vector
<
std
::
vector
<
index_t
>>
reduceDims_
=
{{
1
}};
typename
ReferenceInstance
::
Invoker
ref_instance_invoker_
;
};
}
// namespace ck
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment