Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dec32dc6
Commit
dec32dc6
authored
Jan 31, 2025
by
ThomasNing
Browse files
Finish the feature and merge with develop on the computeV2
parents
71352c44
c5fff071
Changes
215
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
870 additions
and
795 deletions
+870
-795
example/ck_tile/35_batched_transpose/batched_transpose_example.cpp
...k_tile/35_batched_transpose/batched_transpose_example.cpp
+261
-0
example/ck_tile/35_batched_transpose/batched_transpose_example.hpp
...k_tile/35_batched_transpose/batched_transpose_example.hpp
+25
-0
example/ck_tile/35_batched_transpose/script/smoke_test.sh
example/ck_tile/35_batched_transpose/script/smoke_test.sh
+11
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck/ck.hpp
include/ck/ck.hpp
+17
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+1
-14
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+213
-687
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
...gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+1
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+0
-2
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+17
-6
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+40
-63
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+2
-1
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
+210
-0
include/ck_tile/core/arch/arch.hpp
include/ck_tile/core/arch/arch.hpp
+51
-6
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+15
-3
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+1
-1
No files found.
example/ck_tile/35_batched_transpose/batched_transpose_example.cpp
0 → 100644
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 4);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto m =
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
std::cout << m;
if(v != len[3] - 1)
std::cout << ",";
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
}
}
std::cout << "]" << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template
<
typename
DataType
>
auto
get_elimit
(
std
::
string
/*init_method*/
)
{
double
rtol
=
1e-3
;
double
atol
=
1e-3
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
bf16_t
>
(
std
::
string
/*init_method*/
)
{
double
rtol
=
1e-2
;
double
atol
=
1e-2
;
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
template
<
>
auto
get_elimit
<
ck_tile
::
fp8_t
>
(
std
::
string
init_method
)
{
if
(
init_method
==
"ui"
||
init_method
==
"ni"
)
{
unsigned
max_rounding_point_distance
=
0
;
double
atol
=
2e-3
;
return
ck_tile
::
make_tuple
(
max_rounding_point_distance
,
atol
);
}
else
{
unsigned
max_rounding_point_distance
=
1
;
double
atol
=
0.0625
;
return
ck_tile
::
make_tuple
(
max_rounding_point_distance
,
atol
);
}
}
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"v"
,
"1"
,
"whether do CPU validation or not"
)
.
insert
(
"pr"
,
"fp16"
,
"input data type. fp16/fp32 (representing 8/16/32 bit data)"
)
.
insert
(
"N"
,
"2"
,
"input batch size. "
)
.
insert
(
"C"
,
"16"
,
"input channel size."
)
.
insert
(
"H"
,
"1"
,
"input height size."
)
.
insert
(
"W"
,
"16"
,
"input width size. "
)
.
insert
(
"layout_in"
,
"NCHW"
,
"input tensor data layout - NCHW by default"
)
.
insert
(
"layout_out"
,
"NHWC"
,
"output tensor data layout - NHWC by default "
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"kname"
,
"0"
,
"t to 1 will print kernel name"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
template
<
typename
Type
>
bool
run_batched_transpose
(
ck_tile
::
ArgParser
args
)
{
int
validate
=
args
.
get_int
(
"v"
);
std
::
string
prec
=
args
.
get_str
(
"pr"
);
int
N
=
args
.
get_int
(
"N"
);
int
C
=
args
.
get_int
(
"C"
);
int
H
=
args
.
get_int
(
"H"
);
int
W
=
args
.
get_int
(
"W"
);
std
::
string
layout_in
=
args
.
get_str
(
"layout_in"
);
std
::
string
layout_out
=
args
.
get_str
(
"layout_out"
);
int
seed
=
args
.
get_int
(
"seed"
);
int
dim_in
[
4
],
dim_out
[
4
];
int
stride_dim_in
[
4
],
stride_dim_out
[
4
];
bool
nchw2nhwc
=
layout_in
==
"NCHW"
&&
layout_out
==
"NHWC"
;
bool
nhwc2nchw
=
layout_in
==
"NHWC"
&&
layout_out
==
"NCHW"
;
assert
(
nchw2nhwc
!=
nhwc2nchw
);
(
void
)
nhwc2nchw
;
dim_in
[
0
]
=
N
;
dim_in
[
1
]
=
nchw2nhwc
?
C
:
H
;
dim_in
[
2
]
=
nchw2nhwc
?
H
:
W
;
dim_in
[
3
]
=
nchw2nhwc
?
W
:
C
;
dim_out
[
0
]
=
N
;
dim_out
[
1
]
=
nchw2nhwc
?
H
:
C
;
dim_out
[
2
]
=
nchw2nhwc
?
W
:
H
;
dim_out
[
3
]
=
nchw2nhwc
?
C
:
W
;
stride_dim_in
[
0
]
=
C
*
H
*
W
;
stride_dim_in
[
1
]
=
nchw2nhwc
?
H
*
W
:
C
*
W
;
stride_dim_in
[
2
]
=
nchw2nhwc
?
W
:
C
;
stride_dim_in
[
3
]
=
1
;
stride_dim_out
[
0
]
=
C
*
H
*
W
;
stride_dim_out
[
1
]
=
nchw2nhwc
?
C
*
W
:
H
*
W
;
stride_dim_out
[
2
]
=
nchw2nhwc
?
C
:
W
;
stride_dim_out
[
3
]
=
1
;
if
(
seed
<
0
)
{
seed
=
std
::
time
(
nullptr
);
}
ck_tile
::
HostTensor
<
Type
>
x_host
(
{
dim_in
[
0
],
dim_in
[
1
],
dim_in
[
2
],
dim_in
[
3
]},
{
stride_dim_in
[
0
],
stride_dim_in
[
1
],
stride_dim_in
[
2
],
stride_dim_in
[
3
]});
ck_tile
::
HostTensor
<
Type
>
y_host
(
{
dim_out
[
0
],
dim_out
[
1
],
dim_out
[
2
],
dim_out
[
3
]},
{
stride_dim_out
[
0
],
stride_dim_out
[
1
],
stride_dim_out
[
2
],
stride_dim_out
[
3
]});
ck_tile
::
FillUniformDistribution
<
Type
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
DeviceMem
x_dev
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_dev
(
y_host
.
get_element_space_size_in_bytes
());
x_dev
.
ToDevice
(
x_host
.
data
());
auto
trait
=
batched_transpose_trait
{
prec
,
layout_in
};
uint32_t
height
=
nchw2nhwc
?
C
:
H
*
W
;
uint32_t
width
=
nchw2nhwc
?
H
*
W
:
C
;
batched_transpose_kargs
karg
=
[
&
]()
{
batched_transpose_kargs
a_
;
a_
.
p_input
=
x_dev
.
GetDeviceBuffer
();
a_
.
p_output
=
y_dev
.
GetDeviceBuffer
();
a_
.
batch
=
N
;
a_
.
height
=
height
;
a_
.
width
=
width
;
return
a_
;
}();
ck_tile
::
stream_config
sc
{
nullptr
,
true
};
auto
ms
=
batched_transpose
(
trait
,
karg
,
sc
);
std
::
size_t
num_operations
=
N
*
C
*
H
*
(
W
-
1
);
std
::
size_t
num_bytes
=
N
*
C
*
H
*
W
*
sizeof
(
Type
);
float
ave_time
=
ms
*
1E-3
;
float
gb_per_sec
=
num_bytes
/
ms
*
1.E-6
;
float
tflops
=
static_cast
<
float
>
(
num_operations
)
/
ms
*
1.E-6
;
std
::
cout
<<
"Run Batched Transpose kernel with N="
<<
N
<<
", C="
<<
C
<<
", H="
<<
H
<<
", W="
<<
W
<<
", layout_in="
<<
layout_in
<<
", layout_out="
<<
layout_out
<<
" : "
<<
ms
<<
" ms ("
<<
ave_time
<<
" ave_time), "
<<
tflops
<<
" TFlops"
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
printf
(
"[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f
\n
"
,
prec
.
c_str
(),
N
,
C
,
H
,
W
,
layout_in
.
c_str
(),
ms
);
if
(
ms
<
0
)
printf
(
"not supported
\n
"
);
fflush
(
stdout
);
if
(
ms
<
0
)
{
return
false
;
}
y_dev
.
FromDevice
(
y_host
.
data
());
bool
rtn
=
true
;
if
(
validate
)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile
::
HostTensor
<
Type
>
y_ref
(
{
dim_out
[
0
],
dim_out
[
1
],
dim_out
[
2
],
dim_out
[
3
]},
{
stride_dim_out
[
0
],
stride_dim_out
[
1
],
stride_dim_out
[
2
],
stride_dim_out
[
3
]});
ck_tile
::
reference_batched_transpose
<
Type
>
(
x_host
,
y_ref
,
layout_in
,
layout_out
);
auto
[
rtol
,
atol
]
=
get_elimit
<
Type
>
(
""
);
rtn
&=
ck_tile
::
check_err
(
y_host
,
y_ref
,
std
::
string
(
"y Error: Incorrect results!"
),
rtol
,
atol
);
}
printf
(
"valid:%s
\n
"
,
rtn
?
"y"
:
"n"
);
fflush
(
stdout
);
return
rtn
;
}
int
main
(
int
argc
,
char
**
argv
)
{
auto
[
result
,
args
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
std
::
string
prec
=
args
.
get_str
(
"pr"
);
bool
r
=
true
;
if
(
prec
.
compare
(
"fp32"
)
==
0
)
{
r
&=
run_batched_transpose
<
float
>
(
args
);
}
else
if
(
prec
.
compare
(
"fp16"
)
==
0
)
{
r
&=
run_batched_transpose
<
ck_tile
::
fp16_t
>
(
args
);
}
else
if
(
prec
.
compare
(
"bf16"
)
==
0
)
{
r
&=
run_batched_transpose
<
ck_tile
::
bf16_t
>
(
args
);
}
else
if
(
prec
.
compare
(
"int8"
)
==
0
)
{
r
&=
run_batched_transpose
<
ck_tile
::
int8_t
>
(
args
);
}
return
r
?
0
:
-
1
;
}
example/ck_tile/35_batched_transpose/batched_transpose_example.hpp
0 → 100644
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct
batched_transpose_trait
{
std
::
string
type
;
std
::
string
layout
;
};
struct
batched_transpose_kargs
:
public
ck_tile
::
BatchedTransposeHostArgs
{
};
float
batched_transpose
(
batched_transpose_trait
t
,
batched_transpose_kargs
a
,
ck_tile
::
stream_config
s
);
example/ck_tile/35_batched_transpose/script/smoke_test.sh
0 → 100755
View file @
dec32dc6
#!/bin/sh
EXE
=
./build/bin/tile_example_batched_transpose
for
pr
in
"fp32"
"fp16"
"int8"
;
do
$EXE
-pr
=
$pr
-N
=
1
-C
=
32
-H
=
1
-W
=
32
-layout_in
=
'NCHW'
-layout_out
=
'NHWC'
$EXE
-pr
=
$pr
-N
=
2
-C
=
12
-H
=
1
-W
=
32
-layout_in
=
'NHWC'
-layout_out
=
'NCHW'
$EXE
-pr
=
$pr
-N
=
3
-C
=
1334
-H
=
1
-W
=
37
-layout_in
=
'NHWC'
-layout_out
=
'NCHW'
$EXE
-pr
=
$pr
-N
=
4
-C
=
27
-H
=
1
-W
=
32
-layout_in
=
'NCHW'
-layout_out
=
'NHWC'
$EXE
-pr
=
$pr
-N
=
5
-C
=
1234
-H
=
1
-W
=
12
-layout_in
=
'NCHW'
-layout_out
=
'NHWC'
done
example/ck_tile/CMakeLists.txt
View file @
dec32dc6
...
@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
...
@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
16_batched_gemm
)
add_subdirectory
(
16_batched_gemm
)
add_subdirectory
(
17_grouped_gemm
)
add_subdirectory
(
17_grouped_gemm
)
add_subdirectory
(
35_batched_transpose
)
include/ck/ck.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// to do: add various levels of logging with CK_LOG_LEVEL
// to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 1
#define CK_TIME_KERNEL 1
#endif
// constant address space for kernel parameter
// constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
...
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
#define CK_USE_SR_F8_CONVERSION 0
...
@@ -230,13 +235,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -230,13 +235,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx908
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
#define CK_WORKAROUND_SWDEV_388832 1
// denorm test fix, required to work around dissue
// denorm test fix, necessary for gfx90a
#ifndef CK_WORKAROUND_DENORM_FIX
#ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_WORKAROUND_DENORM_FIX 0
#define CK_GFX90A_DENORM_WORKAROUND 0
#endif // CK_GFX90A_DENORM_WORKAROUND
// Enable only for gfx90a
#if defined(__gfx90a__)
#if CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 1
#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1
#else
#else
// enable only for gfx90a
#define CK_GFX90A_DENORM_WORKAROUND 0
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // gfx90a
#endif // CK_WORKAROUND_DENORM_FIX
// set flag to 1 to build deprecated instances
// set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1
#define CK_BUILD_DEPRECATED 1
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -121,19 +121,6 @@ __global__ void
...
@@ -121,19 +121,6 @@ __global__ void
static_for
<
0
,
NumDTensor
,
1
>
{}(
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_group_offset
[
i
];
});
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_group_offset
[
i
];
});
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
a_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
b_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
is_same_v
<
CDEElementwiseOperation
,
element_wise
::
DynamicUnaryOp
>
)
{
cde_element_op
.
InitUnaryOpPtrOnDevice
();
}
if
constexpr
(
isMultiA
||
isMultiB
)
if
constexpr
(
isMultiA
||
isMultiB
)
{
{
AsPointer
p_as_grid_grp
;
AsPointer
p_as_grid_grp
;
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -247,32 +247,6 @@ struct DequantPack8
...
@@ -247,32 +247,6 @@ struct DequantPack8
constexpr
const
static
bool
is_pack8_invocable
=
true
;
constexpr
const
static
bool
is_pack8_invocable
=
true
;
};
};
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct
UnaryOpBase
{
public:
__host__
__device__
~
UnaryOpBase
()
=
default
;
__host__
__device__
constexpr
UnaryOpBase
()
=
default
;
__host__
__device__
constexpr
UnaryOpBase
(
const
UnaryOpBase
&
)
=
default
;
__host__
__device__
constexpr
UnaryOpBase
(
UnaryOpBase
&&
)
=
default
;
__host__
__device__
UnaryOpBase
&
operator
=
(
const
UnaryOpBase
&
)
=
default
;
__host__
__device__
UnaryOpBase
&
operator
=
(
UnaryOpBase
&&
)
=
default
;
__host__
__device__
virtual
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
=
0
;
__host__
__device__
virtual
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
=
0
;
};
struct
PassThroughPack2
struct
PassThroughPack2
{
{
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
@@ -304,27 +278,8 @@ struct PassThroughPack2
...
@@ -304,27 +278,8 @@ struct PassThroughPack2
constexpr
const
static
bool
is_pack2_invocable
=
true
;
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
};
struct
PassThrough
final
:
public
UnaryOpBase
struct
PassThrough
{
{
__host__
__device__
constexpr
PassThrough
()
=
default
;
__host__
__device__
constexpr
PassThrough
(
const
PassThrough
&
)
=
default
;
__host__
__device__
constexpr
PassThrough
(
PassThrough
&&
)
=
default
;
__host__
__device__
PassThrough
&
operator
=
(
const
PassThrough
&
)
=
default
;
__host__
__device__
PassThrough
&
operator
=
(
PassThrough
&&
)
=
default
;
__host__
__device__
~
PassThrough
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
y
=
x
;
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
x
;
}
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
...
@@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase
...
@@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase
y
=
x
;
y
=
x
;
}
}
template
<
>
__host__
__device__
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
x
;
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
__host__
__device__
void
operator
()
<
float
,
double
>
(
float
&
y
,
const
double
&
x
)
const
{
{
...
@@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase
...
@@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase
y
=
type_convert
<
double
>
(
x
);
y
=
type_convert
<
double
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
x
;
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
{
y
=
type_convert
<
half_t
>
(
x
);
y
=
type_convert
<
half_t
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
int32_t
,
int32_t
>
(
int32_t
&
y
,
const
int32_t
&
x
)
const
{
y
=
x
;
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
{
...
@@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase
...
@@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase
y
=
type_convert
<
float
>
(
x
);
y
=
type_convert
<
float
>
(
x
);
}
}
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
int8_t
>
(
half_t
&
y
,
const
int8_t
&
x
)
const
__host__
__device__
void
operator
()
<
half_t
,
int8_t
>
(
half_t
&
y
,
const
int8_t
&
x
)
const
{
{
...
@@ -675,45 +666,20 @@ struct UnarySquare
...
@@ -675,45 +666,20 @@ struct UnarySquare
};
};
};
};
struct
UnaryAbs
final
:
public
UnaryOpBase
struct
UnaryAbs
{
{
__host__
__device__
constexpr
UnaryAbs
()
=
default
;
template
<
typename
T
>
__host__
__device__
constexpr
UnaryAbs
(
const
UnaryAbs
&
)
=
default
;
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
UnaryAbs
(
UnaryAbs
&&
)
=
default
;
__host__
__device__
UnaryAbs
&
operator
=
(
const
UnaryAbs
&
)
=
default
;
__host__
__device__
UnaryAbs
&
operator
=
(
UnaryAbs
&&
)
=
default
;
__host__
__device__
~
UnaryAbs
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
{
y
=
ck
::
math
::
abs
(
x
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
}
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
ck
::
math
::
abs
(
x
);
y
=
ck
::
math
::
abs
(
x
);
}
}
;
template
<
>
__host__
__device__
void
operator
()(
f8_t
&
y
,
const
f8_t
&
x
)
const
__host__
__device__
void
operator
()(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
{
y
=
ck
::
type_convert
<
f8_t
>
(
ck
::
math
::
abs
(
ck
::
type_convert
<
float
>
(
x
)));
y
=
ck
::
type_convert
<
f8_t
>
(
ck
::
math
::
abs
(
ck
::
type_convert
<
float
>
(
x
)));
...
@@ -732,41 +698,20 @@ struct UnarySqrt
...
@@ -732,41 +698,20 @@ struct UnarySqrt
};
};
};
};
struct
Relu
final
:
public
UnaryOpBase
struct
Relu
{
{
__host__
__device__
constexpr
Relu
()
=
default
;
template
<
typename
T
>
__host__
__device__
constexpr
Relu
(
const
Relu
&
)
=
default
;
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
Relu
(
Relu
&&
)
=
default
;
__host__
__device__
Relu
&
operator
=
(
const
Relu
&
)
=
default
;
__host__
__device__
Relu
&
operator
=
(
Relu
&&
)
=
default
;
__host__
__device__
~
Relu
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
x
>
0
?
x
:
0
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
>
0
?
x
:
0
;
y
=
x
>
0
?
x
:
0
;
}
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
template
<
>
__host__
__device__
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
{
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
float
y_f32
=
x_f32
>
0
?
x_f32
:
0
;
float
y_f32
=
x_f32
>
0
?
x_f32
:
0
;
...
@@ -913,52 +858,18 @@ struct Gelu
...
@@ -913,52 +858,18 @@ struct Gelu
}
}
};
};
struct
Sigmoid
final
:
public
UnaryOpBase
struct
Sigmoid
{
{
__host__
__device__
constexpr
Sigmoid
()
=
default
;
template
<
typename
T
>
__host__
__device__
constexpr
Sigmoid
(
const
Sigmoid
&
)
=
default
;
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
Sigmoid
(
Sigmoid
&&
)
=
default
;
__host__
__device__
Sigmoid
&
operator
=
(
const
Sigmoid
&
)
=
default
;
__host__
__device__
Sigmoid
&
operator
=
(
Sigmoid
&&
)
=
default
;
__host__
__device__
~
Sigmoid
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
constexpr
float
one
=
type_convert
<
float
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
constexpr
double
one
=
type_convert
<
double
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
constexpr
int32_t
one
=
type_convert
<
int32_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
constexpr
int8_t
one
=
type_convert
<
int8_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
constexpr
half_t
one
=
type_convert
<
half_t
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
{
constexpr
float
one
=
type_convert
<
float
>
(
1
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
float
x_f32
=
ck
::
type_convert
<
float
>
(
x
);
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
float
y_f32
=
one
/
(
one
+
ck
::
math
::
exp
(
x_f32
));
is_same
<
T
,
int32_t
>::
value
,
y
=
ck
::
type_convert
<
bhalf_t
>
(
y_f32
);
"Data type is not supported by this operation!"
);
}
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
one
/
(
one
+
ck
::
math
::
exp
(
-
x
));
};
};
};
struct
Silu
struct
Silu
...
@@ -974,44 +885,18 @@ struct Silu
...
@@ -974,44 +885,18 @@ struct Silu
};
};
};
};
struct
TanH
final
:
public
UnaryOpBase
struct
TanH
{
{
__host__
__device__
constexpr
TanH
()
=
default
;
template
<
typename
T
>
__host__
__device__
constexpr
TanH
(
const
TanH
&
)
=
default
;
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
constexpr
TanH
(
TanH
&&
)
=
default
;
__host__
__device__
TanH
&
operator
=
(
const
TanH
&
)
=
default
;
__host__
__device__
TanH
&
operator
=
(
TanH
&&
)
=
default
;
__host__
__device__
~
TanH
()
=
default
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
{
y
=
ck
::
math
::
tanh
(
x
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
}
is_same
<
T
,
ck
::
half_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"Data type is not supported by this operation!"
);
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
y
=
ck
::
math
::
tanh
(
x
);
y
=
ck
::
math
::
tanh
(
x
);
}
}
;
};
};
struct
ACos
struct
ACos
...
@@ -1252,418 +1137,138 @@ struct Rcp
...
@@ -1252,418 +1137,138 @@ struct Rcp
};
};
};
};
struct
Swish
final
:
public
UnaryOpBase
struct
Swish
{
{
__host__
__device__
constexpr
Swish
(
const
Swish
&
)
=
default
;
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
__host__
__device__
constexpr
Swish
(
Swish
&&
)
=
default
;
__host__
__device__
~
Swish
()
=
default
;
__host__
__device__
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
const
float
beta_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
float
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
double
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
int32_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
int8_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
half_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
bhalf_t
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
{
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
is_same
<
X
,
half
_t
>::
value
,
is_same
<
X
,
ck
::
half_t
>::
value
||
is_same
<
X
,
int8
_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
static_assert
(
is_same
<
Y
,
float
>::
value
||
is_same
<
Y
,
double
>::
value
||
static_assert
(
is_same
<
Y
,
float
>::
value
||
is_same
<
Y
,
double
>::
value
||
is_same
<
Y
,
half
_t
>::
value
,
is_same
<
Y
,
ck
::
half_t
>::
value
||
is_same
<
Y
,
int8
_t
>::
value
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
float
bx
=
-
beta_
*
type_convert
<
float
>
(
x
);
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
y
=
type_convert
<
Y
>
(
x
/
(
1.
f
+
ck
::
math
::
exp
(
bx
)));
}
};
const
float
beta_
;
};
};
struct
SoftRelu
final
:
public
UnaryOpBase
struct
SoftRelu
{
{
__host__
__device__
constexpr
SoftRelu
(
const
SoftRelu
&
)
=
default
;
SoftRelu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
SoftRelu
(
SoftRelu
&&
)
=
default
;
__host__
__device__
~
SoftRelu
()
=
default
;
__host__
__device__
SoftRelu
(
float
alpha
=
1.0
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
constexpr
float
one
=
type_convert
<
float
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
constexpr
double
one
=
type_convert
<
double
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
constexpr
int32_t
one
=
type_convert
<
int32_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
constexpr
int8_t
one
=
type_convert
<
int8_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
constexpr
half_t
one
=
type_convert
<
half_t
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
constexpr
bhalf_t
one
=
type_convert
<
bhalf_t
>
(
1
);
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
ck
::
math
::
log
(
one
+
ck
::
math
::
exp
(
x
*
casted_alpha
))
/
casted_alpha
;
}
}
const
float
alpha_
;
};
};
struct
Power
final
:
public
UnaryOpBase
struct
Power
{
{
__host__
__device__
constexpr
Power
(
const
Power
&
)
=
default
;
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
__host__
__device__
constexpr
Power
(
Power
&&
)
=
default
;
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
){};
__host__
__device__
~
Power
()
=
default
;
__host__
__device__
Power
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
,
float
gamma
=
2.
f
)
template
<
typename
T
>
:
alpha_
(
alpha
),
beta_
(
beta
),
gamma_
(
gamma
)
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
T
casted_gamma
=
type_convert
<
T
>
(
gamma_
);
T
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
__host__
__device__
float
get_gamma
()
const
{
return
gamma_
;
}
const
float
alpha_
;
const
float
alpha_
;
const
float
beta_
;
const
float
beta_
;
const
float
gamma_
;
const
float
gamma_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
float
casted_beta
=
type_convert
<
float
>
(
beta_
);
float
casted_gamma
=
type_convert
<
float
>
(
gamma_
);
float
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
double
casted_beta
=
type_convert
<
double
>
(
beta_
);
double
casted_gamma
=
type_convert
<
double
>
(
gamma_
);
double
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
int32_t
casted_beta
=
type_convert
<
int32_t
>
(
beta_
);
int32_t
casted_gamma
=
type_convert
<
int32_t
>
(
gamma_
);
int32_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
int8_t
casted_beta
=
type_convert
<
int8_t
>
(
beta_
);
int8_t
casted_gamma
=
type_convert
<
int8_t
>
(
gamma_
);
int8_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
half_t
casted_beta
=
type_convert
<
half_t
>
(
beta_
);
half_t
casted_gamma
=
type_convert
<
half_t
>
(
gamma_
);
half_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
bhalf_t
casted_beta
=
type_convert
<
bhalf_t
>
(
beta_
);
bhalf_t
casted_gamma
=
type_convert
<
bhalf_t
>
(
gamma_
);
bhalf_t
shifted_scaled_x
=
casted_alpha
+
casted_beta
*
x
;
y
=
ck
::
math
::
pow
(
shifted_scaled_x
,
casted_gamma
);
}
};
};
struct
ClippedRelu
final
:
public
UnaryOpBase
struct
ClippedRelu
{
{
__host__
__device__
constexpr
ClippedRelu
(
const
ClippedRelu
&
)
=
default
;
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
:
alpha_
(
alpha
),
beta_
(
beta
){};
__host__
__device__
constexpr
ClippedRelu
(
ClippedRelu
&&
)
=
default
;
__host__
__device__
~
ClippedRelu
()
=
default
;
__host__
__device__
ClippedRelu
(
float
alpha
=
0.
f
,
float
beta
=
1.
f
)
template
<
typename
T
>
:
alpha_
(
alpha
),
beta_
(
beta
)
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
T
casted_beta
=
type_convert
<
T
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
__host__
__device__
float
get_beta
()
const
{
return
beta_
;
}
const
float
alpha_
;
const
float
alpha_
;
const
float
beta_
;
const
float
beta_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
float
casted_beta
=
type_convert
<
float
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
double
casted_beta
=
type_convert
<
double
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
int32_t
casted_beta
=
type_convert
<
int32_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
int8_t
casted_beta
=
type_convert
<
int8_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
half_t
casted_beta
=
type_convert
<
half_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
bhalf_t
casted_beta
=
type_convert
<
bhalf_t
>
(
beta_
);
y
=
ck
::
math
::
min
(
casted_beta
,
ck
::
math
::
max
(
casted_alpha
,
x
));
}
};
};
struct
LeakyRelu
final
:
public
UnaryOpBase
struct
LeakyRelu
{
{
__host__
__device__
constexpr
LeakyRelu
(
const
LeakyRelu
&
)
=
default
;
LeakyRelu
(
float
alpha
=
0.01
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
LeakyRelu
(
LeakyRelu
&&
)
=
default
;
__host__
__device__
~
LeakyRelu
()
=
default
;
__host__
__device__
LeakyRelu
(
float
alpha
=
0.
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
template
<
typename
T
>
{
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
__host__
__device__
inline
void
operator
()([[
maybe_unused
]]
bhalf_t
&
y
,
[[
maybe_unused
]]
const
bhalf_t
&
x
)
const
final
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>=
0
?
x
:
x
*
casted_alpha
;
}
}
const
float
alpha_
;
};
};
struct
Elu
final
:
public
UnaryOpBase
struct
Elu
{
{
__host__
__device__
constexpr
Elu
(
const
Elu
&
)
=
default
;
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
Elu
(
Elu
&&
)
=
default
;
__host__
__device__
~
Elu
()
=
default
;
__host__
__device__
Elu
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
template
<
typename
T
>
{
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
y
=
x
>
0
?
x
:
casted_alpha
*
ck
::
math
::
expm1
(
x
);
}
}
const
float
alpha_
;
};
};
struct
Logistic
final
:
public
UnaryOpBase
struct
Logistic
{
{
__host__
__device__
constexpr
Logistic
(
const
Logistic
&
)
=
default
;
Logistic
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
__host__
__device__
constexpr
Logistic
(
Logistic
&&
)
=
default
;
__host__
__device__
~
Logistic
()
=
default
;
__host__
__device__
Logistic
(
float
alpha
=
1.0
f
)
:
alpha_
(
alpha
)
{}
__host__
__device__
float
get_alpha
()
const
{
return
alpha_
;
}
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
const
float
alpha_
;
__host__
__device__
inline
void
operator
()(
float
&
y
,
const
float
&
x
)
const
final
{
float
casted_alpha
=
type_convert
<
float
>
(
alpha_
);
constexpr
float
one
=
type_convert
<
float
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
__host__
__device__
inline
void
operator
()(
double
&
y
,
const
double
&
x
)
const
final
{
double
casted_alpha
=
type_convert
<
double
>
(
alpha_
);
constexpr
double
one
=
type_convert
<
double
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
__host__
__device__
inline
void
operator
()(
int32_t
&
y
,
const
int32_t
&
x
)
const
final
{
int32_t
casted_alpha
=
type_convert
<
int32_t
>
(
alpha_
);
constexpr
int32_t
one
=
type_convert
<
int32_t
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
__host__
__device__
inline
void
operator
()(
int8_t
&
y
,
const
int8_t
&
x
)
const
final
{
int8_t
casted_alpha
=
type_convert
<
int8_t
>
(
alpha_
);
constexpr
int8_t
one
=
type_convert
<
int8_t
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
__host__
__device__
inline
void
operator
()(
half_t
&
y
,
const
half_t
&
x
)
const
final
{
half_t
casted_alpha
=
type_convert
<
half_t
>
(
alpha_
);
constexpr
half_t
one
=
type_convert
<
half_t
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
__host__
__device__
inline
void
operator
()(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
final
{
{
bhalf_t
casted_alpha
=
type_convert
<
bhalf_t
>
(
alpha_
);
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
constexpr
bhalf_t
one
=
type_convert
<
bhalf_t
>
(
1
);
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
}
const
float
alpha_
;
};
};
struct
ConvInvscale
struct
ConvInvscale
...
@@ -1728,7 +1333,7 @@ struct ConvScaleRelu
...
@@ -1728,7 +1333,7 @@ struct ConvScaleRelu
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
e
,
const
float
&
c
)
const
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
e
,
const
float
&
c
)
const
{
{
float
x
;
float
x
;
Relu
{}(
x
,
c
*
scale_in_
*
scale_wei_
);
Relu
{}
.
template
operator
()
<
float
>
(
x
,
c
*
scale_in_
*
scale_wei_
);
e
=
type_convert
<
f8_t
>
(
x
*
scale_out_
);
e
=
type_convert
<
f8_t
>
(
x
*
scale_out_
);
};
};
...
@@ -1809,225 +1414,138 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
...
@@ -1809,225 +1414,138 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
struct
DynamicUnaryOp
struct
DynamicUnaryOp
{
{
DynamicUnaryOp
&
operator
=
(
const
DynamicUnaryOp
&
other
)
{
if
(
this
!=
&
other
)
{
unary_op_ptr_
=
other
.
unary_op_ptr_
;
unary_op_type_
=
other
.
unary_op_type_
;
}
return
*
this
;
}
__host__
__device__
DynamicUnaryOp
()
=
delete
;
__host__
__device__
DynamicUnaryOp
()
=
delete
;
__host__
__device__
DynamicUnaryOp
(
const
Swish
&
swish
)
__host__
__device__
DynamicUnaryOp
(
const
Swish
&
swish
)
:
unary_op_type_
(
UnaryOpType
::
Swish
),
swish_
{
swish
.
beta_
}
{
{
unary_op_type_
=
UnaryOpType
::
Swish
;
beta
=
swish
.
get_beta
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Swish
&&
swish
)
__host__
__device__
DynamicUnaryOp
(
const
Swish
&&
swish
)
:
unary_op_type_
(
UnaryOpType
::
Swish
),
swish_
{
swish
.
beta_
}
{
{
unary_op_type_
=
UnaryOpType
::
Swish
;
beta
=
swish
.
get_beta
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Sigmoid
&
)
{
unary_op_type_
=
UnaryOpType
::
Sigmoid
;
}
__host__
__device__
DynamicUnaryOp
(
const
Sigmoid
&
)
:
unary_op_type_
(
UnaryOpType
::
Sigmoid
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
Sigmoid
&&
)
{
unary_op_type_
=
UnaryOpType
::
Sigmoid
;
}
__host__
__device__
DynamicUnaryOp
(
const
Sigmoid
&&
)
:
unary_op_type_
(
UnaryOpType
::
Sigmoid
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
PassThrough
&
)
__host__
__device__
DynamicUnaryOp
(
const
PassThrough
&
)
:
unary_op_type_
(
UnaryOpType
::
PassThrough
)
{
{
unary_op_type_
=
UnaryOpType
::
PassThrough
;
}
}
__host__
__device__
DynamicUnaryOp
(
const
PassThrough
&&
)
__host__
__device__
DynamicUnaryOp
(
const
PassThrough
&&
)
:
unary_op_type_
(
UnaryOpType
::
PassThrough
)
{
{
unary_op_type_
=
UnaryOpType
::
PassThrough
;
}
}
__host__
__device__
DynamicUnaryOp
(
const
Logistic
&
logistic
)
__host__
__device__
DynamicUnaryOp
(
const
Logistic
&
logistic
)
:
unary_op_type_
(
UnaryOpType
::
Logistic
),
logistic_
{
logistic
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
Logistic
;
alpha
=
logistic
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Logistic
&&
logistic
)
__host__
__device__
DynamicUnaryOp
(
const
Logistic
&&
logistic
)
:
unary_op_type_
(
UnaryOpType
::
Logistic
),
logistic_
{
logistic
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
Logistic
;
alpha
=
logistic
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
TanH
&
)
{
unary_op_type_
=
UnaryOpType
::
TanH
;
}
__host__
__device__
DynamicUnaryOp
(
const
TanH
&
)
:
unary_op_type_
(
UnaryOpType
::
TanH
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
TanH
&&
)
{
unary_op_type_
=
UnaryOpType
::
TanH
;
}
__host__
__device__
DynamicUnaryOp
(
const
TanH
&&
)
:
unary_op_type_
(
UnaryOpType
::
TanH
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
Relu
&
)
{
unary_op_type_
=
UnaryOpType
::
Relu
;
}
__host__
__device__
DynamicUnaryOp
(
const
Relu
&
)
:
unary_op_type_
(
UnaryOpType
::
Relu
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
Relu
&&
)
{
unary_op_type_
=
UnaryOpType
::
Relu
;
}
__host__
__device__
DynamicUnaryOp
(
const
Relu
&&
)
:
unary_op_type_
(
UnaryOpType
::
Relu
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
SoftRelu
&
softrelu
)
__host__
__device__
DynamicUnaryOp
(
const
SoftRelu
&
softrelu
)
:
unary_op_type_
(
UnaryOpType
::
SoftRelu
),
soft_relu_
{
softrelu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
SoftRelu
;
alpha
=
softrelu
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
SoftRelu
&&
softrelu
)
__host__
__device__
DynamicUnaryOp
(
const
SoftRelu
&&
softrelu
)
:
unary_op_type_
(
UnaryOpType
::
SoftRelu
),
soft_relu_
{
softrelu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
SoftRelu
;
alpha
=
softrelu
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
UnaryAbs
&
)
{
unary_op_type_
=
UnaryOpType
::
UnaryAbs
;
}
__host__
__device__
DynamicUnaryOp
(
const
UnaryAbs
&
)
:
unary_op_type_
(
UnaryOpType
::
UnaryAbs
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
UnaryAbs
&&
)
{
unary_op_type_
=
UnaryOpType
::
UnaryAbs
;
}
__host__
__device__
DynamicUnaryOp
(
const
UnaryAbs
&&
)
:
unary_op_type_
(
UnaryOpType
::
UnaryAbs
)
{
}
__host__
__device__
DynamicUnaryOp
(
const
Power
&
pow
)
__host__
__device__
DynamicUnaryOp
(
const
Power
&
pow
)
:
unary_op_type_
(
UnaryOpType
::
Power
),
power_
(
pow
.
alpha_
,
pow
.
beta_
,
pow
.
gamma_
)
{
{
unary_op_type_
=
UnaryOpType
::
Power
;
alpha
=
pow
.
get_alpha
();
beta
=
pow
.
get_beta
();
gamma
=
pow
.
get_gamma
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Power
&&
pow
)
__host__
__device__
DynamicUnaryOp
(
const
Power
&&
pow
)
:
unary_op_type_
(
UnaryOpType
::
Power
),
power_
(
pow
.
alpha_
,
pow
.
beta_
,
pow
.
gamma_
)
{
{
unary_op_type_
=
UnaryOpType
::
Power
;
alpha
=
pow
.
get_alpha
();
beta
=
pow
.
get_beta
();
gamma
=
pow
.
get_gamma
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
ClippedRelu
&
clippedrelu
)
__host__
__device__
DynamicUnaryOp
(
const
ClippedRelu
&
clippedrelu
)
:
unary_op_type_
(
UnaryOpType
::
ClippedRelu
),
clipped_relu_
{
clippedrelu
.
alpha_
,
clippedrelu
.
beta_
}
{
{
unary_op_type_
=
UnaryOpType
::
ClippedRelu
;
alpha
=
clippedrelu
.
get_alpha
();
beta
=
clippedrelu
.
get_beta
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
ClippedRelu
&&
clippedrelu
)
__host__
__device__
DynamicUnaryOp
(
const
ClippedRelu
&&
clippedrelu
)
:
unary_op_type_
(
UnaryOpType
::
ClippedRelu
),
clipped_relu_
{
clippedrelu
.
alpha_
,
clippedrelu
.
beta_
}
{
{
unary_op_type_
=
UnaryOpType
::
ClippedRelu
;
alpha
=
clippedrelu
.
get_alpha
();
beta
=
clippedrelu
.
get_beta
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
LeakyRelu
&
leakyrelu
)
__host__
__device__
DynamicUnaryOp
(
const
LeakyRelu
&
leakyrelu
)
:
unary_op_type_
(
UnaryOpType
::
LeakyRelu
),
leaky_relu_
{
leakyrelu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
LeakyRelu
;
alpha
=
leakyrelu
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
LeakyRelu
&&
leakyrelu
)
__host__
__device__
DynamicUnaryOp
(
const
LeakyRelu
&&
leakyrelu
)
:
unary_op_type_
(
UnaryOpType
::
LeakyRelu
),
leaky_relu_
{
leakyrelu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
LeakyRelu
;
alpha
=
leakyrelu
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Elu
&
elu
)
__host__
__device__
DynamicUnaryOp
(
const
Elu
&
elu
)
:
unary_op_type_
(
UnaryOpType
::
Elu
),
elu_
{
elu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
Elu
;
alpha
=
elu
.
get_alpha
();
}
}
__host__
__device__
DynamicUnaryOp
(
const
Elu
&&
elu
)
__host__
__device__
DynamicUnaryOp
(
const
Elu
&&
elu
)
:
unary_op_type_
(
UnaryOpType
::
Elu
),
elu_
{
elu
.
alpha_
}
{
{
unary_op_type_
=
UnaryOpType
::
Elu
;
alpha
=
elu
.
get_alpha
();
}
__host__
__device__
DynamicUnaryOp
(
const
DynamicUnaryOp
&
dynamic_op
)
:
unary_op_type_
(
dynamic_op
.
unary_op_type_
),
unary_op_ptr_
(
dynamic_op
.
unary_op_ptr_
),
alpha
(
dynamic_op
.
alpha
),
beta
(
dynamic_op
.
beta
),
gamma
(
dynamic_op
.
gamma
)
{
}
__host__
__device__
~
DynamicUnaryOp
()
{
switch
(
unary_op_type_
)
{
case
(
UnaryOpType
::
Swish
):
delete
static_cast
<
Swish
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
Sigmoid
):
delete
static_cast
<
Sigmoid
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
PassThrough
):
delete
static_cast
<
PassThrough
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
Logistic
):
delete
static_cast
<
Logistic
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
TanH
):
delete
static_cast
<
TanH
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
Relu
):
delete
static_cast
<
Relu
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
SoftRelu
):
delete
static_cast
<
SoftRelu
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
UnaryAbs
):
delete
static_cast
<
UnaryAbs
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
Power
):
delete
static_cast
<
Power
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
ClippedRelu
):
delete
static_cast
<
ClippedRelu
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
LeakyRelu
):
delete
static_cast
<
LeakyRelu
*>
(
unary_op_ptr_
);
break
;
case
(
UnaryOpType
::
Elu
):
delete
static_cast
<
Elu
*>
(
unary_op_ptr_
);
break
;
default:
break
;
}
}
}
__device__
void
InitUnaryOpPtrOnDevice
()
__host__
__device__
DynamicUnaryOp
(
const
DynamicUnaryOp
&
dynamic_op
)
=
default
;
{
switch
(
unary_op_type_
)
{
case
(
UnaryOpType
::
Swish
):
unary_op_ptr_
=
new
Swish
(
beta
);
break
;
case
(
UnaryOpType
::
Sigmoid
):
unary_op_ptr_
=
new
Sigmoid
;
break
;
case
(
UnaryOpType
::
PassThrough
):
unary_op_ptr_
=
new
PassThrough
;
break
;
case
(
UnaryOpType
::
Logistic
):
unary_op_ptr_
=
new
Logistic
(
alpha
);
break
;
case
(
UnaryOpType
::
TanH
):
unary_op_ptr_
=
new
TanH
;
break
;
case
(
UnaryOpType
::
Relu
):
unary_op_ptr_
=
new
Relu
;
break
;
case
(
UnaryOpType
::
SoftRelu
):
unary_op_ptr_
=
new
SoftRelu
(
alpha
);
break
;
case
(
UnaryOpType
::
UnaryAbs
):
unary_op_ptr_
=
new
UnaryAbs
;
break
;
case
(
UnaryOpType
::
Power
):
unary_op_ptr_
=
new
Power
(
alpha
,
beta
,
gamma
);
break
;
case
(
UnaryOpType
::
ClippedRelu
):
unary_op_ptr_
=
new
ClippedRelu
(
alpha
,
beta
);
break
;
case
(
UnaryOpType
::
LeakyRelu
):
unary_op_ptr_
=
new
LeakyRelu
(
alpha
);
break
;
case
(
UnaryOpType
::
Elu
):
unary_op_ptr_
=
new
Elu
(
alpha
);
break
;
default:
unary_op_ptr_
=
nullptr
;
break
;
}
}
template
<
typename
Y
,
typename
X
>
__host__
__device__
~
DynamicUnaryOp
()
{}
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
isSupported
<
X
,
Y
>
();
unary_op_ptr_
->
operator
()(
y
,
x
);
}
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
{
isSupported
<
X
,
Y
>
();
switch
(
unary_op_type_
)
switch
(
unary_op_type_
)
{
{
case
(
UnaryOpType
::
Swish
):
S
wish
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Swish
):
s
wish
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Sigmoid
):
S
igmoid
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Sigmoid
):
s
igmoid
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
PassThrough
):
P
ass
T
hrough
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
PassThrough
):
p
ass
_t
hrough
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Logistic
):
L
ogistic
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Logistic
):
l
ogistic
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
TanH
):
T
an
H
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
TanH
):
t
an
h_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Relu
):
R
elu
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Relu
):
r
elu
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
SoftRelu
):
S
oft
R
elu
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
SoftRelu
):
s
oft
_r
elu
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
UnaryAbs
):
U
nary
Abs
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
UnaryAbs
):
u
nary
_abs_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Power
):
P
ower
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Power
):
p
ower
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
ClippedRelu
):
C
lipped
R
elu
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
ClippedRelu
):
c
lipped
_r
elu
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
LeakyRelu
):
L
eaky
R
elu
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
LeakyRelu
):
l
eaky
_r
elu
_
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Elu
):
E
lu
{}.
operator
()
(
y
,
x
);
break
;
case
(
UnaryOpType
::
Elu
):
e
lu
_
(
y
,
x
);
break
;
default:
break
;
default:
break
;
}
}
}
}
template
<
typename
X
,
typename
Y
>
template
<
>
__
device__
__host__
constexpr
void
isSupported
(
)
const
__
host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
{
float
y_float
;
static_assert
(
std
::
is_same
<
X
,
Y
>::
value
,
"X and Y must be of the same type"
);
float
x_float
=
type_convert
<
float
>
(
x
);
this
->
operator
()(
y_float
,
x_float
);
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
y
=
type_convert
<
bhalf_t
>
(
y_float
);
is_same
<
X
,
bhalf_t
>::
value
||
is_same
<
X
,
half_t
>::
value
||
is_same
<
X
,
int32_t
>::
value
||
is_same
<
X
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
}
}
private:
private:
...
@@ -2049,12 +1567,20 @@ struct DynamicUnaryOp
...
@@ -2049,12 +1567,20 @@ struct DynamicUnaryOp
public:
public:
UnaryOpType
unary_op_type_
;
UnaryOpType
unary_op_type_
;
UnaryOpBase
*
unary_op_ptr_
=
nullptr
;
float
alpha
;
Swish
swish_
;
float
beta
;
Sigmoid
sigmoid_
;
float
gamma
;
PassThrough
pass_through_
;
Logistic
logistic_
;
TanH
tanh_
;
Relu
relu_
;
SoftRelu
soft_relu_
;
UnaryAbs
unary_abs_
;
Power
power_
;
ClippedRelu
clipped_relu_
;
LeakyRelu
leaky_relu_
;
Elu
elu_
;
};
};
#pragma clang diagnostic pop
}
// namespace element_wise
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
dec32dc6
...
@@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
...
@@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
#if CK_
WORKAROUND_DENORM_FIX
#if CK_
GFX90A_DENORM_WORKAROUND
using
AComputeDataType
=
using
AComputeDataType
=
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
using
BComputeDataType
=
using
BComputeDataType
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
dec32dc6
...
@@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using
GridwiseGemmPipe
=
remove_cvref_t
<
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
#if CK_
WORKAROUND_DENORM_FIX
#if CK_
GFX90A_DENORM_WORKAROUND
using
AComputeDataType
=
using
AComputeDataType
=
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
using
BComputeDataType
=
using
BComputeDataType
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
View file @
dec32dc6
...
@@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
...
@@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using
GridwiseGemmPipe
=
remove_cvref_t
<
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
#if CK_
WORKAROUND_DENORM_FIX
#if CK_
GFX90A_DENORM_WORKAROUND
using
AComputeDataType
=
using
AComputeDataType
=
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
conditional_t
<
is_same_v
<
AComputeDataType_
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
AComputeDataType_
>
;
#else
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
dec32dc6
...
@@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
...
@@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// when mfma if fixed, remove this section and update
// when mfma if fixed, remove this section and update
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
// throughout this file
// throughout this file
#if CK_
WORKAROUND_DENORM_FIX
#if CK_
GFX90A_DENORM_WORKAROUND
using
FloatAAdjusted
=
using
FloatAAdjusted
=
conditional_t
<
is_same_v
<
ComputeTypeA
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
ComputeTypeA
>
;
conditional_t
<
is_same_v
<
ComputeTypeA
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
ComputeTypeA
>
;
using
FloatBAdjusted
=
using
FloatBAdjusted
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
dec32dc6
...
@@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
// FloatABAdjusted -> FloatAB throughout this file
#if CK_
WORKAROUND_DENORM_FIX
#if CK_
GFX90A_DENORM_WORKAROUND
using
FloatABAdjusted
=
conditional_t
<
is_same_v
<
FloatAB
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
FloatAB
>
;
using
FloatABAdjusted
=
conditional_t
<
is_same_v
<
FloatAB
,
ck
::
half_t
>
,
ck
::
bhalf_t
,
FloatAB
>
;
#else
#else
using
FloatABAdjusted
=
FloatAB
;
using
FloatABAdjusted
=
FloatAB
;
...
...
include/ck/utility/data_type.hpp
View file @
dec32dc6
...
@@ -19,8 +19,6 @@ struct pk_i4_t
...
@@ -19,8 +19,6 @@ struct pk_i4_t
type
data
;
type
data
;
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
operator
float
()
const
{
return
static_cast
<
int8_t
>
(
data
);
}
};
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
dec32dc6
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize
element_space_size_
;
ElementSpaceSize
element_space_size_
;
T
invalid_element_value_
=
T
{
0
};
T
invalid_element_value_
=
T
{
0
};
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
T
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
{
{
...
@@ -82,14 +89,18 @@ struct DynamicBuffer
...
@@ -82,14 +89,18 @@ struct DynamicBuffer
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
t_per_x
,
coherence
>
(
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
}
else
else
{
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
,
t_per_x
,
coherence
>
(
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
,
invalid_element_value_
);
}
}
}
}
else
else
...
@@ -191,7 +202,7 @@ struct DynamicBuffer
...
@@ -191,7 +202,7 @@ struct DynamicBuffer
dst_buf
.
p_data_
,
dst_buf
.
p_data_
,
dst_offset
,
dst_offset
,
is_valid_element
,
is_valid_element
,
element_space_size_
);
element_space_size_
/
PackedSize
);
}
}
template
<
typename
X
,
template
<
typename
X
,
...
@@ -226,7 +237,7 @@ struct DynamicBuffer
...
@@ -226,7 +237,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
&&
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
&&
...
@@ -378,7 +389,7 @@ struct DynamicBuffer
...
@@ -378,7 +389,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
}
else
else
{
{
...
@@ -417,7 +428,7 @@ struct DynamicBuffer
...
@@ -417,7 +428,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
}
else
if
(
is_valid_element
)
else
if
(
is_valid_element
)
{
{
...
...
include/ck/utility/type_convert.hpp
View file @
dec32dc6
...
@@ -14,6 +14,41 @@ namespace ck {
...
@@ -14,6 +14,41 @@ namespace ck {
#define __gfx94__
#define __gfx94__
#endif
#endif
// Declare a template function for bf16 conversion using RTN
// (round-to-nearest; ties are resolved to even via the bias trick below).
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    // Nan check: NaN is the only value for which x != x. Return a canonical
    // quiet-NaN bf16 pattern (0x7FC0). NOTE(review): the sign of the incoming
    // NaN is dropped here.
    if(x != x)
    {
        return uint16_t(0x7FC0);
    }

    // Reinterpret the fp32 bit pattern; bf16 is exactly the upper 16 bits.
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    // Round-to-nearest-even: add 0x7FFF plus the least-significant kept
    // mantissa bit, so a tie (discarded lower 16 bits == 0x8000) rounds up
    // only when the kept mantissa is odd. A carry out of the mantissa
    // naturally increments the exponent (rounding up to the next binade,
    // or to +/-inf at the top of the finite range).
    const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
    constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);
    return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
}
// Convert fp16 to bf16 with round-to-nearest: widen to fp32 first, then
// reuse the fp32 -> bf16 RTN specialization above.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    return bf16_convert_rtn<bhalf_t>(static_cast<float>(x));
}
// Convert X to Y, both X and Y are non-const data types.
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
template
<
typename
Y
,
typename
X
,
typename
X
,
...
@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
...
@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return
u
.
fp32
;
return
u
.
fp32
;
}
}
// convert fp32 to bfp16
// convert fp32 to bfp16, round to nearest even
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
{
union
#if CK_USE_RNE_BF16_CONVERSION
{
return
bf16_convert_rtn
<
bhalf_t
>
(
x
);
float
fp32
;
#else
uint32_t
int32
;
}
u
=
{
x
};
return
uint16_t
(
u
.
int32
>>
16
);
return
uint16_t
(
u
.
int32
>>
16
);
#endif
}
}
// convert bfp16 to fp16 via fp32
// convert bfp16 to fp16 via fp32
...
@@ -615,60 +648,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
...
@@ -615,60 +648,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
}
}
}
}
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
bf16_convert_rtn
(
X
x
);
// Convert fp32 to bf16 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
return
uint16_t
(
u
.
int32
>>
16
);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
}
// namespace ck
}
// namespace ck
include/ck_tile/core.hpp
View file @
dec32dc6
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...
@@ -53,8 +54,8 @@
...
@@ -53,8 +54,8 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
0 → 100644
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace
ck_tile
{
/**
* @brief Enumeration describing static tile distribution patterns.
*
*/
enum
struct
tile_distribution_pattern
{
/**
* @brief Thread raked pattern.
*
*/
thread_raked
,
/**
* @brief Warp raked pattern.
*
*/
warp_raked
,
/**
* @brief Block raked pattern - aka linear.
*
*/
block_raked
,
};
struct
TileDistributionEncodingPattern
{
};
/**
 * @brief Class creating 2D static tile distribution with different load/store patterns.
 *
 * @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
 * is contiguous and we can do vector load on this dimension.
 *
 * @tparam BlockSize Number of threads in a workgroup.
 * @tparam YPerTile The tile size of outer/leftmost dimension.
 * @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
 * @tparam VecSize The vector access size.
 * @tparam DistributionPattern The enumeration describing used access pattern.
 */
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
    // Primary template is intentionally empty; only the per-pattern
    // specializations (thread_raked / warp_raked / block_raked) are usable.
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::thread_raked>
    : public TileDistributionEncodingPattern
{
    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    // X dim is split as X0 threads, each doing X1-wide vector accesses.
    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

    // # of rows in Y dim accessed by single wavefront in one iteration
    static constexpr index_t Y1 = warp_size / X0;
    static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
    static constexpr index_t Y0 = num_warps;
    // YPerWarp = YPerTile / Y0;
    // Y2 = YPerWarp / Y1;
    static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
    static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<2, 1>>{});
    }

    // Same Y/X factorization with the dimension order swapped
    // (X-split tuple first, Y-split tuple second).
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 2>>{});
    }
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::warp_raked>
    : public TileDistributionEncodingPattern
{
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    // X dim is split as X0 threads, each doing X1-wide vector accesses.
    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1;  // # of threads in X dim
    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
    static constexpr index_t Y0 = num_warps;
    static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
    static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }

    // Same Y/X factorization with the dimension order swapped
    // (X-split tuple first, Y-split tuple second).
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::block_raked>
    : public TileDistributionEncodingPattern
{
    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    // X dim is split as X0 threads, each doing X1-wide vector accesses.
    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1;  // # of threads in X dim
    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
    static constexpr index_t Y1 = num_warps;
    static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
    static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Same Y/X factorization with the dimension order swapped
    // (X-split tuple first, Y-split tuple second).
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 0>>{});
    }
};
}
// namespace ck_tile
include/ck_tile/core/arch/arch.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,18 +12,37 @@
...
@@ -12,18 +12,37 @@
namespace
ck_tile
{
namespace
ck_tile
{
enum
struct
address_space_enum
// safe_underlying_type<T, IsEnum>::type yields std::underlying_type_t<T> when
// T is an enum and void otherwise. Unlike applying std::underlying_type to T
// directly, this never instantiates underlying_type on a non-enum type.
template <typename, bool>
struct safe_underlying_type;

// Enum case: expose the enum's real underlying integer type.
template <typename T>
struct safe_underlying_type<T, true>
{
    using type = std::underlying_type_t<T>;
};

// Non-enum case: fall back to void.
template <typename T>
struct safe_underlying_type<T, false>
{
    using type = void;
};

// Convenience alias that detects whether T is an enum automatically.
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
enum
struct
address_space_enum
:
std
::
uint16_t
{
{
generic
,
generic
=
0
,
global
,
global
,
lds
,
lds
,
sgpr
,
sgpr
,
vgpr
,
constant
,
vgpr
};
};
enum
struct
memory_operation_enum
enum
struct
memory_operation_enum
:
std
::
uint16_t
{
{
set
,
set
=
0
,
atomic_add
,
atomic_add
,
atomic_max
,
atomic_max
,
add
add
...
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
...
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif
#endif
}
}
// Expands to clang's address_space attribute with the numeric value of
// address_space_enum::constant, obtained via safe_underlying_type_t so the
// enum's underlying type is used for the cast.
#define CK_CONSTANT_ADDRESS_SPACE \
    __attribute__((address_space( \
        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))

// Strip the "Constant" address-space qualifier from a pointer so it can be
// used as an ordinary generic pointer in device code.
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
// Add the "Constant" address-space qualifier to a generic pointer
// (inverse of cast_pointer_to_generic_address_space).
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
    // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/config.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx9__
#define __gfx9__
#endif
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#define __gfx94__
#endif
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...
@@ -230,3 +230,15 @@
...
@@ -230,3 +230,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
include/ck_tile/core/container/tuple.hpp
View file @
dec32dc6
...
@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
...
@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
using
Idx
=
number
<
tuple
<
Ts
...
>::
size
()
-
i
-
1
>
;
using
Idx
=
number
<
tuple
<
Ts
...
>::
size
()
-
i
-
1
>
;
return
t
.
at
(
Idx
{});
return
t
.
at
(
Idx
{});
},
},
number
<
tuple
<
Ts
...
>::
size
()
()
>
{});
number
<
tuple
<
Ts
...
>::
size
()
>
{});
}
}
// Reduce tuple values in specific range using Function
// Reduce tuple values in specific range using Function
...
...
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment