Unverified commit 7d50244e, authored Oct 31, 2024 by Illia Silin; committed by GitHub on Oct 31, 2024.

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository

Parents: f221c2b0, d51701d4

Changes: 291
Showing 20 changed files with 1232 additions and 184 deletions.
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp (+12 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp (+14 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp (+14 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp (+14 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp (+13 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp (+12 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp (+12 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp (+67 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh (+38 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh (+31 -0)
example/ck_tile/CMakeLists.txt (+4 -0)
include/ck/host_utility/flush_cache.hpp (+38 -17)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+2 -2)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+16 -3)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+24 -4)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+797 -152)
include/ck/utility/amd_xdlops.hpp (+6 -6)
include/ck/utility/data_type.hpp (+9 -0)
include/ck_tile/core.hpp (+5 -0)
include/ck_tile/core/algorithm/coordinate_transform.hpp (+104 -0)
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn  vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn   vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128,  8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256,  4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256,  2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn   vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256,  8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256,  4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn   vn  pd    x    3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256,  8, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256,  4, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true, true>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true, true>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn  vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn  tm  tn  vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"

// clang-format off
//                                                   rm  rn   tm  tn  vn  pd    x     3p
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 3,  4, 64, 4, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 6,  4, 64, 2, true, true, false>>(const S&, A);
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true, true, false>>(const S&, A);
// clang-format on
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <iostream>

#pragma once

using S = ck_tile::stream_config;
using A = add_rmsnorm2d_rdquant_fwd_args;

template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kSaveInvRms_,
          bool kTwoPass_>
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<DataType_,
                                                 Repeat_M_,
                                                 Repeat_N_,
                                                 ThreadPerBlock_M_,
                                                 ThreadPerBlock_N_,
                                                 Vector_N_,
                                                 kPadN_,
                                                 kSaveInvRms_,
                                                 kTwoPass_>;

template <typename Traits_>
float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
        typename AddRmsnormRdquantTypeConfig<DataType>::ADataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::BDataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::GammaDataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::ComputeDataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::XDataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::YScaleDataType,
        typename AddRmsnormRdquantTypeConfig<DataType>::QYDataType,
        typename Traits_::Shape,
        Traits_::kPadN,
        Traits_::kSaveX,
        Traits_::kThreePass>;

    using OnePassPipeline   = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<PipelineProblem>;
    using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<PipelineProblem>;
    using Pipeline = std::conditional_t<Traits_::kThreePass, ThreePassPipeline, OnePassPipeline>;

    using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;

    const dim3 grids      = Kernel::GridSize(a);
    constexpr dim3 blocks = Kernel::BlockSize();

    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh (new file, mode 100755)

# run from top of ck folder
EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd

$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000

$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh (new file, mode 100755)

#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd

for pr_i in "fp16" "bf16" ; do
    $EXE -prec=$pr_i -m=99 -n=13
    $EXE -prec=$pr_i -m=17 -n=16
    $EXE -prec=$pr_i -m=1 -n=100
    $EXE -prec=$pr_i -m=4 -n=128
    $EXE -prec=$pr_i -m=80 -n=127
    $EXE -prec=$pr_i -m=22 -n=255 -stride=256
    $EXE -prec=$pr_i -m=7 -n=599
    $EXE -prec=$pr_i -m=19 -n=512
    $EXE -prec=$pr_i -m=33 -n=313 -stride=1000
    $EXE -prec=$pr_i -m=11 -n=510
    $EXE -prec=$pr_i -m=171 -n=676 -stride=818
    $EXE -prec=$pr_i -m=91 -n=636
    $EXE -prec=$pr_i -m=12 -n=768 -stride=800
    $EXE -prec=$pr_i -m=100 -n=766 -stride=812
    $EXE -prec=$pr_i -m=31 -n=1024
    $EXE -prec=$pr_i -m=64 -n=1000 -stride=1004
    $EXE -prec=$pr_i -m=8 -n=1501
    $EXE -prec=$pr_i -m=3 -n=1826
    $EXE -prec=$pr_i -m=5 -n=2040
    $EXE -prec=$pr_i -m=7 -n=2734
    $EXE -prec=$pr_i -m=1 -n=3182
    $EXE -prec=$pr_i -m=9 -n=4096
    $EXE -prec=$pr_i -m=3 -n=8192
    $EXE -prec=$pr_i -m=1 -n=10547
    $EXE -prec=$pr_i -m=3 -n=17134
done
example/ck_tile/CMakeLists.txt

@@ -7,3 +7,7 @@ add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
 add_subdirectory(04_img2col)
 add_subdirectory(05_reduce)
+add_subdirectory(06_permute)
+add_subdirectory(09_topk_softmax)
+add_subdirectory(10_rmsnorm2d)
+add_subdirectory(11_add_rmsnorm2d_rdquant)
include/ck/host_utility/flush_cache.hpp

@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
-#define MEDIAN 1
+#define MEDIAN 0
     if(stream_config.time_kernel_)
     {
         if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))

@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 #else
         float total_time = 0;
 #endif
+        hipEvent_t start, stop;
+
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
         for(int i = 0; i < nrepeat; ++i)
         {
             if constexpr(!TimePreprocess)

@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                 preprocess();
             }
-            hipEvent_t start, stop;
+            // hipEvent_t start, stop;
 
-            hip_check_error(hipEventCreate(&start));
-            hip_check_error(hipEventCreate(&stop));
+            // hip_check_error(hipEventCreate(&start));
+            // hip_check_error(hipEventCreate(&stop));
 
-            hip_check_error(hipDeviceSynchronize());
-            hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+            // hip_check_error(hipDeviceSynchronize());
+            // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
 
             // calculate preprocess time
             if constexpr(TimePreprocess)
             {

@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             hip_check_error(hipGetLastError());
             // end real kernel
 
-            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
-            hip_check_error(hipEventSynchronize(stop));
-            float cur_time = 0;
-            hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
-#if MEDIAN
-            times.insert(cur_time);
-#else
-            total_time += cur_time;
-#endif
+            // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+            // hip_check_error(hipEventSynchronize(stop));
+            // float cur_time = 0;
+            // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+            // #if MEDIAN
+            // times.insert(cur_time);
+            // #else
+            // total_time += cur_time;
+            // #endif
 
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
+                       static_cast<const void*>(gemm_args.p_a_grid),
+                       static_cast<const void*>(gemm_args.p_b_grid));
             }
         }
+
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+        float cur_time = 0;
+        hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+#if MEDIAN
+        times.insert(cur_time);
+#else
+        total_time += cur_time;
+#endif
 
 #if MEDIAN
         auto mid = times.begin();

@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             return (*mid + *mid_next) / 2;
         }
 #else
-        return total_time / nrepeat;
+        // return total_time / nrepeat;
+        hipDeviceProp_t deviceProps;
+        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
+        float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
+        return (total_time - preprocess_offset * nrepeat) / nrepeat;
 #endif
     }
     else
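Net effect of these hunks: the per-iteration hipEvent timing is commented out and replaced by a single start/stop pair bracketing the whole repeat loop, and the non-MEDIAN path subtracts a per-iteration preprocess offset chosen by multiprocessor count (0.005 vs 0.01, in the milliseconds that hipEventElapsedTime reports). A standalone sketch of the resulting measurement pattern; dummy_kernel and time_whole_loop are hypothetical stand-ins, not CK symbols:

    #include <hip/hip_runtime.h>

    __global__ void dummy_kernel() {} // stand-in for the kernel under test

    // Time nrepeat back-to-back launches with one event pair, as the new code does.
    float time_whole_loop(int nrepeat, hipStream_t stream)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipDeviceSynchronize();
        hipEventRecord(start, stream);
        for(int i = 0; i < nrepeat; ++i)
            hipLaunchKernelGGL(dummy_kernel, dim3(1), dim3(64), 0, stream);
        hipEventRecord(stop, stream);
        hipEventSynchronize(stop);

        float total_ms = 0;
        hipEventElapsedTime(&total_ms, start, stop); // elapsed time in milliseconds
        return total_ms / nrepeat; // the patched code additionally subtracts a preprocess offset
    }

This trades per-iteration resolution for lower overhead: event records no longer serialize every launch, at the cost of losing the per-launch samples the MEDIAN path used to collect inside the loop.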
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(
+                    wmma_gemm.template Run<>(
                         a_thread_vec.template AsType<wmma_input_type_a>(),
                         b_thread_vec.template AsType<wmma_input_type_b>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(
+                    wmma_gemm.template Run<>(
                         a_thread_vec.template AsType<wmma_input_type_a>(),
                         b_thread_vec.template AsType<wmma_input_type_b>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
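The only change here is Run becoming Run<>. Inside a template, wmma_gemm is a dependent object, and the grammar requires the name after the `template` disambiguator to be a template-id, i.e. a name followed by a (possibly empty) argument list; `wmma_gemm.template Run(...)` without angle brackets is rejected by stricter compilers. A minimal self-contained illustration, with made-up types that are not CK's:

    struct Gemm
    {
        template <typename T = float>
        void Run(T) {}
    };

    template <typename G>
    void call(G g)
    {
        // `Run` is a member template of a dependent type, so the `template`
        // keyword is needed, and it must be followed by a template-id: Run<>.
        g.template Run<>(1.0f); // T still deduced as float from the argument
    }

    int main() { call(Gemm{}); }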
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -85,9 +85,9 @@ __global__ void
                BsPointer p_bs_grid,
                DsPointer p_ds_grid,
                EDataType* __restrict__ p_e_grid,
-               const AElementwiseOperation a_element_op,
-               const BElementwiseOperation b_element_op,
-               const CDEElementwiseOperation cde_element_op,
+               AElementwiseOperation a_element_op,
+               BElementwiseOperation b_element_op,
+               CDEElementwiseOperation cde_element_op,
                const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
                const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -121,6 +121,19 @@ __global__ void
        static_for<0, NumDTensor, 1>{}(
            [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

+       if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
+       {
+           a_element_op.InitUnaryOpPtrOnDevice();
+       }
+       if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
+       {
+           b_element_op.InitUnaryOpPtrOnDevice();
+       }
+       if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
+       {
+           cde_element_op.InitUnaryOpPtrOnDevice();
+       }
+
        if constexpr(isMultiA || isMultiB)
        {
            AsPointer p_as_grid_grp;
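Two things change in this kernel: the elementwise operations are taken by value without const, so the mutating InitUnaryOpPtrOnDevice() can be called on the kernel's local copies, and each DynamicUnaryOp is initialized on the device before first use. The likely rationale, stated here as an assumption since the commit carries no commentary, is that an object with virtual functions constructed on the host cannot be virtually dispatched on the GPU because its vtable pointer targets host code; only plain scalar state crosses the kernel boundary, and the polymorphic implementation is rebuilt with device-side new. A simplified sketch of that pattern, with all names hypothetical:

    struct Base
    {
        __device__ virtual void apply(float& y, float x) const = 0;
    };

    struct ReluImpl final : public Base
    {
        __device__ void apply(float& y, float x) const override { y = x > 0 ? x : 0; }
    };

    struct DynOp
    {
        int kind_  = 0;       // plain scalar state: safe to copy into a kernel argument
        Base* ptr_ = nullptr; // only valid after device-side init

        __device__ void init() { ptr_ = new ReluImpl; } // mirrors InitUnaryOpPtrOnDevice()
        __device__ void operator()(float& y, float x) const { ptr_->apply(y, x); }
    };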
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -272,6 +272,26 @@ struct MultiplyMultiply
        e = ck::type_convert<ck::bhalf_t>(x0_f);
    }

+    template <>
+    __host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
+        ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
+    {
+        const float x0_f =
+            ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
+
+        e = ck::type_convert<ck::half_t>(x0_f);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
+        ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
+    {
+        const float x0_f =
+            ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
+
+        e = ck::type_convert<ck::bhalf_t>(x0_f);
+    }
 };

 struct MultiplyAddFastGelu

@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu
                                           const float& d1) const
    {
        const float x = c * alpha1_ + alpha2_ * d0 + d1;
-        Relu{}.template operator()<float>(e, x);
+        e = x > 0 ? x : 0;
    }

    template <>

@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu
                       type_convert<float>(d1);

        float result = 0;
-        Relu{}.template operator()<float>(result, x);
+        result = x > 0 ? x : 0;
        e = type_convert<half_t>(result);
    }

@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu
                       type_convert<float>(d1);

        float result = 0;
-        Relu{}.template operator()<float>(result, x);
+        result = x > 0 ? x : 0;
        e = type_convert<bhalf_t>(result);
    }

@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu
        const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;

        float result = 0;
-        Relu{}.template operator()<float>(result, x);
+        result = x > 0 ? x : 0;
        e = type_convert<int8_t>(result);
    }
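The two new MultiplyMultiply specializations accept an int accumulator c and scale it by two factors, converting everything through float; this is the epilogue shape used when, for instance, an int8 GEMM accumulates into int32 and per-channel scales are applied. A hypothetical host-side check of the new half_t path; demo() and the literal values are made up:

    #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

    void demo()
    {
        ck::tensor_operation::element_wise::MultiplyMultiply op{};
        ck::half_t e;
        const int c         = 42; // e.g. an int32 accumulator from an int8 GEMM
        const ck::half_t d0 = ck::type_convert<ck::half_t>(0.5f);
        const ck::half_t d1 = ck::type_convert<ck::half_t>(0.25f);
        op(e, c, d0, d1); // e = half(float(c) * float(d0) * float(d1)) = half(5.25)
    }

The ScaleAddScaleAddRelu hunks inline the clamp directly instead of delegating to Relu, which keeps this header independent of the Relu rewrite in unary_element_wise_operation.hpp below, where Relu's operator() is no longer a template.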
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -7,11 +7,38 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/type_convert.hpp"
+#include <cassert>
 
 namespace ck {
 namespace tensor_operation {
 namespace element_wise {
 
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
+struct UnaryOpBase
+{
+    public:
+    __host__ __device__ ~UnaryOpBase() = default;
+
+    __host__ __device__ constexpr UnaryOpBase()                   = default;
+    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
+    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&)      = default;
+    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
+    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
+
+    __host__ __device__ virtual inline void operator()(float& y, const float& x) const     = 0;
+    __host__ __device__ virtual inline void operator()(double& y, const double& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
+    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
+};
+
 struct PassThroughPack2
 {
     template <typename Y, typename X>
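UnaryOpBase fixes a common virtual interface over the six concrete scalar types the elementwise ops handle; the pragma silences -Wnon-virtual-dtor, presumably because derived objects are deleted through pointers of their concrete type rather than through UnaryOpBase* (see the DynamicUnaryOp destructor in the last hunk of this file). To show the shape of the contract, here is a hypothetical derived op written the way the rewritten ops below are; it is not part of the commit:

    // Hypothetical example: a negation op implementing the UnaryOpBase interface.
    struct UnaryNeg final : public UnaryOpBase
    {
        __host__ __device__ inline void operator()(float& y, const float& x) const final { y = -x; }
        __host__ __device__ inline void operator()(double& y, const double& x) const final { y = -x; }
        __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = -x; }
        __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = -x; }
        __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = -x; }
        __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
        {
            // bhalf_t has no native arithmetic; round-trip through float as the ops below do
            y = ck::type_convert<bhalf_t>(-ck::type_convert<float>(x));
        }
    };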
@@ -25,17 +52,30 @@ struct PassThroughPack2
     constexpr const static bool is_pack2_invocable = true;
 };
 
-struct PassThrough
+struct PassThrough final : public UnaryOpBase
 {
+    __host__ __device__ constexpr PassThrough()                   = default;
+    __host__ __device__ constexpr PassThrough(const PassThrough&) = default;
+    __host__ __device__ constexpr PassThrough(PassThrough&&)      = default;
+    __host__ __device__ PassThrough& operator=(const PassThrough&) = default;
+    __host__ __device__ PassThrough& operator=(PassThrough&&) = default;
+    __host__ __device__ ~PassThrough() = default;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
+    __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
+
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;
 
     template <>
     __host__ __device__ void operator()<double, double>(double& y, const double& x) const
     {
         y = x;
     }
 
     template <>
     __host__ __device__ void operator()<float, double>(float& y, const double& x) const
     {
@@ -48,36 +88,12 @@ struct PassThrough
         y = type_convert<double>(x);
     }
 
-    template <>
-    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
-    {
-        y = x;
-    }
-
-    template <>
-    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
-    {
-        y = x;
-    }
-
     template <>
     __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
     {
         y = type_convert<half_t>(x);
     }
 
-    template <>
-    __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
-    {
-        y = x;
-    }
-
-    template <>
-    __host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
-    {
-        y = x;
-    }
-
     template <>
     __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
     {
@@ -102,12 +118,6 @@ struct PassThrough
         y = type_convert<float>(x);
     }
 
-    template <>
-    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
-    {
-        y = x;
-    }
-
     template <>
     __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
     {
@@ -407,20 +417,45 @@ struct UnarySquare
     };
 };
 
-struct UnaryAbs
+struct UnaryAbs final : public UnaryOpBase
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-
-        y = ck::math::abs(x);
-    };
+    __host__ __device__ constexpr UnaryAbs()                = default;
+    __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
+    __host__ __device__ constexpr UnaryAbs(UnaryAbs&&)      = default;
+    __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
+    __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
+    __host__ __device__ ~UnaryAbs() = default;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        y = ck::math::abs(x);
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        y = ck::math::abs(x);
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        y = ck::math::abs(x);
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        y = ck::math::abs(x);
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        y = ck::math::abs(x);
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        y = ck::math::abs(x);
+    }
 
     template <>
     __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
@@ -439,20 +474,41 @@ struct UnarySqrt
     };
 };
 
-struct Relu
+struct Relu final : public UnaryOpBase
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        y = x > 0 ? x : 0;
-    }
-
-    template <>
-    __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
+    __host__ __device__ constexpr Relu()            = default;
+    __host__ __device__ constexpr Relu(const Relu&) = default;
+    __host__ __device__ constexpr Relu(Relu&&)      = default;
+    __host__ __device__ Relu& operator=(const Relu&) = default;
+    __host__ __device__ Relu& operator=(Relu&&) = default;
+    __host__ __device__ ~Relu() = default;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        y = x > 0 ? x : 0;
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        y = x > 0 ? x : 0;
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        y = x > 0 ? x : 0;
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        y = x > 0 ? x : 0;
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        y = x > 0 ? x : 0;
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
     {
         float x_f32 = ck::type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
@@ -599,18 +655,52 @@ struct Gelu
     }
 };
 
-struct Sigmoid
+struct Sigmoid final : public UnaryOpBase
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
-                          is_same<T, int32_t>::value,
-                      "Data type is not supported by this operation!");
-        constexpr T one = type_convert<T>(1);
-        y               = one / (one + ck::math::exp(-x));
-    };
+    __host__ __device__ constexpr Sigmoid()               = default;
+    __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
+    __host__ __device__ constexpr Sigmoid(Sigmoid&&)      = default;
+    __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
+    __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
+    __host__ __device__ ~Sigmoid() = default;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        constexpr float one = type_convert<float>(1);
+        y                   = one / (one + ck::math::exp(-x));
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        constexpr double one = type_convert<double>(1);
+        y                    = one / (one + ck::math::exp(-x));
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        constexpr int32_t one = type_convert<int32_t>(1);
+        y                     = one / (one + ck::math::exp(-x));
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        constexpr int8_t one = type_convert<int8_t>(1);
+        y                    = one / (one + ck::math::exp(-x));
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        constexpr half_t one = type_convert<half_t>(1);
+        y                    = one / (one + ck::math::exp(-x));
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        constexpr float one = type_convert<float>(1);
+        float x_f32         = ck::type_convert<float>(x);
+        float y_f32         = one / (one + ck::math::exp(-x_f32));
+        y                   = ck::type_convert<bhalf_t>(y_f32);
+    }
 };
 
 struct Silu
@@ -626,18 +716,44 @@ struct Silu
     };
 };
 
-struct TanH
+struct TanH final : public UnaryOpBase
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
-                          is_same<T, int32_t>::value,
-                      "Data type is not supported by this operation!");
-        y = ck::math::tanh(x);
-    };
+    __host__ __device__ constexpr TanH()            = default;
+    __host__ __device__ constexpr TanH(const TanH&) = default;
+    __host__ __device__ constexpr TanH(TanH&&)      = default;
+    __host__ __device__ TanH& operator=(const TanH&) = default;
+    __host__ __device__ TanH& operator=(TanH&&) = default;
+    __host__ __device__ ~TanH() = default;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        y = ck::math::tanh(x);
+    }
 };
 
 struct ACos
@@ -878,138 +994,418 @@ struct Rcp
     };
 };
 
-struct Swish
+struct Swish final : public UnaryOpBase
 {
-    Swish(float beta = 1.0f) : beta_(beta) {}
+    __host__ __device__ constexpr Swish(const Swish&) = default;
+    __host__ __device__ constexpr Swish(Swish&&)      = default;
+    __host__ __device__ ~Swish() = default;
+
+    __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
+
+    __host__ __device__ float get_beta() const { return beta_; }
+
+    const float beta_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<float>(x / (1.f + ck::math::exp(bx)));
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<double>(x / (1.f + ck::math::exp(bx)));
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
+    }
 
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
         static_assert(is_same<X, float>::value || is_same<X, double>::value ||
-                          is_same<X, ck::half_t>::value,
+                          is_same<X, half_t>::value,
                       "Data type is not supported by this operation!");
 
         static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
-                          is_same<Y, ck::half_t>::value,
+                          is_same<Y, half_t>::value,
                       "Data type is not supported by this operation!");
 
         float bx = -beta_ * type_convert<float>(x);
         y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
-    };
-
-    const float beta_;
+    }
 };
 
-struct SoftRelu
+struct SoftRelu final : public UnaryOpBase
 {
-    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+    __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
+    __host__ __device__ constexpr SoftRelu(SoftRelu&&)      = default;
+    __host__ __device__ ~SoftRelu() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha  = type_convert<T>(alpha_);
-        constexpr T one = type_convert<T>(1);
-        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
-    }
-    const float alpha_;
+    __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+
+    const float alpha_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha  = type_convert<float>(alpha_);
+        constexpr float one = type_convert<float>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha  = type_convert<double>(alpha_);
+        constexpr double one = type_convert<double>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha  = type_convert<int32_t>(alpha_);
+        constexpr int32_t one = type_convert<int32_t>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha  = type_convert<int8_t>(alpha_);
+        constexpr int8_t one = type_convert<int8_t>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha  = type_convert<half_t>(alpha_);
+        constexpr half_t one = type_convert<half_t>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        bhalf_t casted_alpha  = type_convert<bhalf_t>(alpha_);
+        constexpr bhalf_t one = type_convert<bhalf_t>(1);
+        y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
 };
 
-struct Power
+struct Power final : public UnaryOpBase
 {
-    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
-        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+    __host__ __device__ constexpr Power(const Power&) = default;
+    __host__ __device__ constexpr Power(Power&&)      = default;
+    __host__ __device__ ~Power() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha     = type_convert<T>(alpha_);
-        T casted_beta      = type_convert<T>(beta_);
-        T casted_gamma     = type_convert<T>(gamma_);
-        T shifted_scaled_x = casted_alpha + casted_beta * x;
-        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
-    }
+    __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma)
+    {
+    }
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+    __host__ __device__ float get_beta() const { return beta_; }
+    __host__ __device__ float get_gamma() const { return gamma_; }
+
+    const float alpha_;
+    const float beta_;
+    const float gamma_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha     = type_convert<float>(alpha_);
+        float casted_beta      = type_convert<float>(beta_);
+        float casted_gamma     = type_convert<float>(gamma_);
+        float shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                      = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha     = type_convert<double>(alpha_);
+        double casted_beta      = type_convert<double>(beta_);
+        double casted_gamma     = type_convert<double>(gamma_);
+        double shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                       = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha     = type_convert<int32_t>(alpha_);
+        int32_t casted_beta      = type_convert<int32_t>(beta_);
+        int32_t casted_gamma     = type_convert<int32_t>(gamma_);
+        int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                        = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha     = type_convert<int8_t>(alpha_);
+        int8_t casted_beta      = type_convert<int8_t>(beta_);
+        int8_t casted_gamma     = type_convert<int8_t>(gamma_);
+        int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                       = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha     = type_convert<half_t>(alpha_);
+        half_t casted_beta      = type_convert<half_t>(beta_);
+        half_t casted_gamma     = type_convert<half_t>(gamma_);
+        half_t shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                       = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        bhalf_t casted_alpha     = type_convert<bhalf_t>(alpha_);
+        bhalf_t casted_beta      = type_convert<bhalf_t>(beta_);
+        bhalf_t casted_gamma     = type_convert<bhalf_t>(gamma_);
+        bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                        = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
 };
 
-struct ClippedRelu
+struct ClippedRelu final : public UnaryOpBase
 {
-    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+    __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
+    __host__ __device__ constexpr ClippedRelu(ClippedRelu&&)      = default;
+    __host__ __device__ ~ClippedRelu() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha = type_convert<T>(alpha_);
-        T casted_beta  = type_convert<T>(beta_);
-        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
-    }
+    __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
+        : alpha_(alpha), beta_(beta)
+    {
+    }
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+    __host__ __device__ float get_beta() const { return beta_; }
+
+    const float alpha_;
+    const float beta_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha = type_convert<float>(alpha_);
+        float casted_beta  = type_convert<float>(beta_);
+        y                  = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha = type_convert<double>(alpha_);
+        double casted_beta  = type_convert<double>(beta_);
+        y                   = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha = type_convert<int32_t>(alpha_);
+        int32_t casted_beta  = type_convert<int32_t>(beta_);
+        y                    = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha = type_convert<int8_t>(alpha_);
+        int8_t casted_beta  = type_convert<int8_t>(beta_);
+        y                   = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha = type_convert<half_t>(alpha_);
+        half_t casted_beta  = type_convert<half_t>(beta_);
+        y                   = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
+        bhalf_t casted_beta  = type_convert<bhalf_t>(beta_);
+        y                    = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
 };
 
-struct LeakyRelu
+struct LeakyRelu final : public UnaryOpBase
 {
-    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+    __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
+    __host__ __device__ constexpr LeakyRelu(LeakyRelu&&)      = default;
+    __host__ __device__ ~LeakyRelu() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha = type_convert<T>(alpha_);
-        y              = x >= 0 ? x : x * casted_alpha;
-    }
-    const float alpha_;
+    __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+
+    const float alpha_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha = type_convert<float>(alpha_);
+        y                  = x >= 0 ? x : x * casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha = type_convert<double>(alpha_);
+        y                   = x >= 0 ? x : x * casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha = type_convert<int32_t>(alpha_);
+        y                    = x >= 0 ? x : x * casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha = type_convert<int8_t>(alpha_);
+        y                   = x >= 0 ? x : x * casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha = type_convert<half_t>(alpha_);
+        y                   = x >= 0 ? x : x * casted_alpha;
+    }
+
+    __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y,
+                                               [[maybe_unused]] const bhalf_t& x) const final
+    {
+    }
 };
 
-struct Elu
+struct Elu final : public UnaryOpBase
 {
-    Elu(float alpha = 1.f) : alpha_(alpha){};
+    __host__ __device__ constexpr Elu(const Elu&) = default;
+    __host__ __device__ constexpr Elu(Elu&&)      = default;
+    __host__ __device__ ~Elu() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha = type_convert<T>(alpha_);
-        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
-    }
-    const float alpha_;
+    __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+
+    const float alpha_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha = type_convert<float>(alpha_);
+        y                  = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha = type_convert<double>(alpha_);
+        y                   = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha = type_convert<int32_t>(alpha_);
+        y                    = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha = type_convert<int8_t>(alpha_);
+        y                   = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha = type_convert<half_t>(alpha_);
+        y                   = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
+        y                    = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
 };
 
-struct Logistic
+struct Logistic final : public UnaryOpBase
 {
-    Logistic(float alpha = 1.f) : alpha_(alpha){};
+    __host__ __device__ constexpr Logistic(const Logistic&) = default;
+    __host__ __device__ constexpr Logistic(Logistic&&)      = default;
+    __host__ __device__ ~Logistic() = default;
 
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
-                          is_same<T, int8_t>::value,
-                      "Data type is not supported by this operation!");
-        T casted_alpha  = type_convert<T>(alpha_);
-        constexpr T one = type_convert<T>(1);
-        y               = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
-    }
-    const float alpha_;
+    __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
+
+    __host__ __device__ float get_alpha() const { return alpha_; }
+
+    const float alpha_;
+
+    __host__ __device__ inline void operator()(float& y, const float& x) const final
+    {
+        float casted_alpha  = type_convert<float>(alpha_);
+        constexpr float one = type_convert<float>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+
+    __host__ __device__ inline void operator()(double& y, const double& x) const final
+    {
+        double casted_alpha  = type_convert<double>(alpha_);
+        constexpr double one = type_convert<double>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+
+    __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
+    {
+        int32_t casted_alpha  = type_convert<int32_t>(alpha_);
+        constexpr int32_t one = type_convert<int32_t>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+
+    __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
+    {
+        int8_t casted_alpha  = type_convert<int8_t>(alpha_);
+        constexpr int8_t one = type_convert<int8_t>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+
+    __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
+    {
+        half_t casted_alpha  = type_convert<half_t>(alpha_);
+        constexpr half_t one = type_convert<half_t>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+
+    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
+    {
+        bhalf_t casted_alpha  = type_convert<bhalf_t>(alpha_);
+        constexpr bhalf_t one = type_convert<bhalf_t>(1);
+        y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
 };
 
 struct ConvInvscale
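Besides the typed overloads, every parameterized op above gains a get_alpha()/get_beta()/get_gamma() accessor and keeps its scalars as plain const floats; that is what lets the type-erasing wrapper in the final hunk capture an op's state on the host and rebuild the op on the device. A small hypothetical usage, where demo_state_capture() is not CK code:

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    void demo_state_capture()
    {
        using ck::tensor_operation::element_wise::ClippedRelu;

        ClippedRelu op{/*alpha=*/0.f, /*beta=*/6.f}; // a "ReLU6"
        float alpha = op.get_alpha();                // 0.f
        float beta  = op.get_beta();                 // 6.f

        ClippedRelu rebuilt{alpha, beta}; // same behavior, rebuilt from scalar state
        float y = 0;
        rebuilt(y, 7.5f);                 // y = min(6, max(0, 7.5)) = 6
    }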
@@ -1074,7 +1470,7 @@ struct ConvScaleRelu
     __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
     {
         float x;
-        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
+        Relu{}(x, c * scale_in_ * scale_wei_);
         e = type_convert<f8_t>(x * scale_out_);
     };
@@ -1153,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
    __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

struct DynamicUnaryOp
{
    DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
    {
        if(this != &other)
        {
            unary_op_ptr_  = other.unary_op_ptr_;
            unary_op_type_ = other.unary_op_type_;
        }
        return *this;
    }

    __host__ __device__ DynamicUnaryOp() = delete;

    __host__ __device__ DynamicUnaryOp(const Swish& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Swish&& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const PassThrough&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }

    __host__ __device__ DynamicUnaryOp(const PassThrough&&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }

    __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const Power& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const Power&& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu&& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
        : unary_op_type_(dynamic_op.unary_op_type_),
          unary_op_ptr_(dynamic_op.unary_op_ptr_),
          alpha(dynamic_op.alpha),
          beta(dynamic_op.beta),
          gamma(dynamic_op.gamma)
    {
    }

    __host__ __device__ ~DynamicUnaryOp()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
        case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
        case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
        case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
        case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
        case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
        case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
        case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
        case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
        default: break;
        }
    }

    __device__ void InitUnaryOpPtrOnDevice()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
        case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
        case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
        case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
        case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
        case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
        case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
        case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
        case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
        case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
        case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
        default: unary_op_ptr_ = nullptr; break;
        }
    }

    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        unary_op_ptr_->operator()(y, x);
    }

    template <typename Y, typename X>
    __host__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): Swish{}.operator()(y, x); break;
        case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break;
        case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break;
        case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break;
        case(UnaryOpType::TanH): TanH{}.operator()(y, x); break;
        case(UnaryOpType::Relu): Relu{}.operator()(y, x); break;
        case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break;
        case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break;
        case(UnaryOpType::Power): Power{}.operator()(y, x); break;
        case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break;
        case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break;
        case(UnaryOpType::Elu): Elu{}.operator()(y, x); break;
        default: break;
        }
    }

    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");

        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    private:
    enum class UnaryOpType
    {
        Swish,
        Sigmoid,
        PassThrough,
        Logistic,
        TanH,
        Relu,
        SoftRelu,
        UnaryAbs,
        Power,
        ClippedRelu,
        LeakyRelu,
        Elu
    };

    public:
    UnaryOpType unary_op_type_;
    UnaryOpBase* unary_op_ptr_ = nullptr;
    float alpha;
    float beta;
    float gamma;
};

#pragma clang diagnostic pop

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
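The intended call pattern for DynamicUnaryOp, as a hedged sketch (the kernel below is illustrative, not part of the diff): the host-side constructors record only the enum tag plus the alpha/beta/gamma scalars, and a device copy must call InitUnaryOpPtrOnDevice() before using the __device__ operator(), which dispatches through the UnaryOpBase pointer. Note that the __host__ operator() switch default-constructs the concrete functor, so the stored parameters take effect only on the device path.

// Illustrative kernel (names are assumptions): each thread holds a by-value
// copy of the op, materializes its concrete functor on the device heap, and
// the copy's destructor frees that functor when the kernel exits.
__global__ void apply_dynamic_op(float* y, const float* x, int n,
                                 ck::tensor_operation::element_wise::DynamicUnaryOp op)
{
    op.InitUnaryOpPtrOnDevice(); // turn the enum tag back into a concrete functor
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        op(y[i], x[i]); // virtual dispatch via UnaryOpBase
}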
include/ck/utility/amd_xdlops.hpp
@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
(The only change in this hunk is the builtin's spelling: __builtin_amdgcn_mfma_i32_16x16x32i8 becomes __builtin_amdgcn_mfma_i32_16x16x32_i8.)

    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};
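For orientation, a minimal sketch of the renamed intrinsic's shape (the wrapper function and vector typedef are illustrative assumptions, not from the repository): it takes the two 8-byte int8 operand packs as 64-bit values, a 4-lane int32 accumulator, and three mode modifiers.

// Sketch only, for a gfx94x target: standalone call to the renamed builtin.
typedef int v4i32 __attribute__((ext_vector_type(4)));

__device__ v4i32 mfma_16x16x32_i8(long a_pack, long b_pack, v4i32 acc)
{
    // cbsz/abid/blgp modifiers are 0, matching the call in Run() above
    return __builtin_amdgcn_mfma_i32_16x16x32_i8(a_pack, b_pack, acc, 0, 0, 0);
}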
include/ck/utility/data_type.hpp
@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t>
    static constexpr int bias = 16; // negative zero nan mode
    // static constexpr int bias = 15; // ieee mode
};

template <>
struct NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};
} // namespace ck
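The new specialization encodes the bf16 bit layout (1 sign, 8 exponent, 7 mantissa bits). A quick compile-time sanity check against this header (a sketch, assuming only the fields shown in the hunk):

#include "ck/utility/data_type.hpp"

// bf16 is float32 with the low 16 mantissa bits truncated, so the exponent
// width matches float (8) and the mantissa shrinks to 7.
static_assert(ck::NumericUtils<ck::bhalf_t>::exp == 8, "bf16 exponent width");
static_assert(ck::NumericUtils<ck::bhalf_t>::mant == 7, "bf16 mantissa width");
static_assert(ck::NumericUtils<ck::bhalf_t>::exp + ck::NumericUtils<ck::bhalf_t>::mant + 1 == 16,
              "bf16 occupies 16 bits including the sign");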
include/ck_tile/core.hpp
@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
@@ -24,6 +25,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
@@ -49,14 +51,17 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
include/ck_tile/core/algorithm/coordinate_transform.hpp
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
    replicate,
    xor_t,
    offset,
    indexing,
};

template <index_t NDimLow, index_t NDimUp>
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
    }
};

template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
    static constexpr index_t NDimUp = 1;

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;
    IndexingAdaptor iadaptor_;

    CK_TILE_HOST_DEVICE constexpr indexing() = default;

    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
                                           const IndexingAdaptor& iadaptor)
        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::indexing;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.calculate_lower_index(idx_low, idx_up);
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up) const
    {
        // TODO: nothing changed here
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               IndexingAdaptor::is_known_at_compile_time();
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("indexing{");
        //
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("}");
    }
};

//*******************************************************************************************************

template <typename LowLength>
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
}
} // namespace ck_tile

#include "ck_tile/core/algorithm/indexing_adaptor.hpp"

namespace ck_tile {

template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
                                                           const Indices& indices)
{
    // by default we use the simplest one
    return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
        up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
}

template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
    return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
}

} // namespace ck_tile
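A usage sketch for the new factory (the index table, extents, and array literal below are illustrative assumptions, not from the diff): make_indexing_transform builds a gather-style transform in which upper coordinate i resolves to the tabulated lower offset, with indexing_adaptor_onshot_cached caching the one-shot table lookup.

// Sketch: a 4-wide gather through an explicit index table.
// Assumes ck_tile::array can be brace-initialized this way.
CK_TILE_HOST_DEVICE auto make_row_gather()
{
    const auto row_ids = ck_tile::array<ck_tile::index_t, 4>{7, 2, 9, 0};

    // Upper index 2 will later resolve to lower offset 9 via the cached adaptor.
    return ck_tile::make_indexing_transform(ck_tile::number<4>{}, row_ids);
}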