gaoqiong / composable_kernel_ROCM / Commits

Commit 4d914af3, authored Oct 31, 2024 by Jun Liu

    Merge branch 'amd-develop' into amd-master

Parents: 223a2abe, 4b798833
Changes: 333
Showing 20 changed files with 2,542 additions and 261 deletions.
- example/ck_tile/CMakeLists.txt (+5 −0)
- include/ck/host_utility/flush_cache.hpp (+38 −17)
- include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+2 −2)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+16 −3)
- include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+24 −4)
- include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+797 −152)
- include/ck/utility/amd_xdlops.hpp (+6 −6)
- include/ck/utility/data_type.hpp (+9 −0)
- include/ck/utility/math_v2.hpp (+5 −5)
- include/ck_tile/core.hpp (+6 −0)
- include/ck_tile/core/algorithm/coordinate_transform.hpp (+104 −0)
- include/ck_tile/core/algorithm/indexing_adaptor.hpp (+60 −0)
- include/ck_tile/core/algorithm/space_filling_curve.hpp (+7 −5)
- include/ck_tile/core/arch/amd_buffer_addressing.hpp (+195 −18)
- include/ck_tile/core/arch/utility.hpp (+43 −0)
- include/ck_tile/core/config.hpp (+20 −0)
- include/ck_tile/core/container/sequence.hpp (+122 −0)
- include/ck_tile/core/container/tuple.hpp (+49 −5)
- include/ck_tile/core/numeric/int8.hpp (+104 −0)
- include/ck_tile/core/numeric/math.hpp (+930 −44)
example/ck_tile/CMakeLists.txt

```diff
@@ -6,3 +6,8 @@ add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
 add_subdirectory(04_img2col)
+add_subdirectory(05_reduce)
+add_subdirectory(06_permute)
+add_subdirectory(09_topk_softmax)
+add_subdirectory(10_rmsnorm2d)
+add_subdirectory(11_add_rmsnorm2d_rdquant)
```
include/ck/host_utility/flush_cache.hpp

The timing loop in `launch_and_time_kernel_with_preprocess` is restructured: `MEDIAN` is switched off, the per-iteration event handling is commented out in favor of a single start/stop pair around the whole loop, and the mean path subtracts a per-iteration preprocess offset:

```diff
@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
-#define MEDIAN 1
+#define MEDIAN 0
     if(stream_config.time_kernel_)
     {
         if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 #else
         float total_time = 0;
 #endif
+        hipEvent_t start, stop;
+
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
         for(int i = 0; i < nrepeat; ++i)
         {
             if constexpr(!TimePreprocess)
@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                 preprocess();
             }
-            hipEvent_t start, stop;
-            hip_check_error(hipEventCreate(&start));
-            hip_check_error(hipEventCreate(&stop));
-            hip_check_error(hipDeviceSynchronize());
-            hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+            // hipEvent_t start, stop;
+            // hip_check_error(hipEventCreate(&start));
+            // hip_check_error(hipEventCreate(&stop));
+            // hip_check_error(hipDeviceSynchronize());
+            // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
             // calculate preprocess time
             if constexpr(TimePreprocess)
             {
@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             hip_check_error(hipGetLastError());
             // end real kernel
-            hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
-            hip_check_error(hipEventSynchronize(stop));
-            float cur_time = 0;
-            hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
-#if MEDIAN
-            times.insert(cur_time);
-#else
-            total_time += cur_time;
-#endif
+            // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+            // hip_check_error(hipEventSynchronize(stop));
+            // float cur_time = 0;
+            // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+            // #if MEDIAN
+            // times.insert(cur_time);
+            // #else
+            // total_time += cur_time;
+            // #endif
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
-                std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
                 printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
                        static_cast<const void*>(gemm_args.p_a_grid),
                        static_cast<const void*>(gemm_args.p_b_grid));
             }
         }
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+        float cur_time = 0;
+        hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+#if MEDIAN
+        times.insert(cur_time);
+#else
+        total_time += cur_time;
+#endif
 #if MEDIAN
         auto mid = times.begin();
@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             return (*mid + *mid_next) / 2;
         }
 #else
-        return total_time / nrepeat;
+        // return total_time / nrepeat;
+        hipDeviceProp_t deviceProps;
+        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
+        float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
+        return (total_time - preprocess_offset * nrepeat) / nrepeat;
 #endif
     }
     else
```
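The net effect of this hunk is to time the whole `nrepeat` loop with one start/stop event pair rather than one pair per launch, then subtract a hard-coded per-iteration preprocess offset (0.005 ms on devices with 80 multiprocessors, 0.01 ms otherwise, in the same milliseconds `hipEventElapsedTime` reports). A minimal, self-contained sketch of that whole-loop timing pattern; `scale_kernel` and the buffer size are hypothetical placeholders, not CK code:

```cpp
// Whole-loop HIP event timing, as in the rewritten flush_cache.hpp path.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* p) { p[threadIdx.x] *= 2.0f; }

int main()
{
    float* d = nullptr;
    hipMalloc(&d, 256 * sizeof(float));

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    const int nrepeat = 100;
    hipDeviceSynchronize();
    hipEventRecord(start, nullptr); // one record before the loop ...
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(scale_kernel, dim3(1), dim3(256), 0, nullptr, d);
    hipEventRecord(stop, nullptr); // ... one record after it
    hipEventSynchronize(stop);

    float total_ms = 0;
    hipEventElapsedTime(&total_ms, start, stop); // elapsed time in ms
    printf("avg per launch: %f ms\n", total_ms / nrepeat);

    hipEventDestroy(start);
    hipEventDestroy(stop);
    hipFree(d);
    return 0;
}
```

Timing outside the loop amortizes the event overhead across launches, but it also folds the per-iteration `preprocess()` cost into the total; the hard-coded `preprocess_offset` subtraction is the commit's compensation for that.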
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

Two call sites in `BlockwiseGemmWMMA` change `wmma_gemm.template Run(...)` to `wmma_gemm.template Run<>(...)`:

```diff
@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                    wmma_gemm.template Run(
+                    wmma_gemm.template Run<>(
                         a_thread_vec.template AsType<wmma_input_type_a>(),
                         b_thread_vec.template AsType<wmma_input_type_b>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
```

The identical one-line change is applied again at @@ -406,7 +406,7.
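The commit does not state why the empty template-argument list was added; one plausible reading is overload disambiguation in the dependent call, since a template-id considers only function templates as candidates. A reduced, hypothetical illustration (stand-in types, not the CK ones):

```cpp
// With `obj.template Run<>(...)`, only function templates named Run are
// candidates; a plain `Run(...)` call would also consider non-template
// overloads.
struct Gemm
{
    void Run(float) {}                          // non-template overload
    template <typename T> void Run(const T&) {} // template overload
};

template <typename G>
void caller(G g)
{
    g.template Run<>(0); // template-id: resolves to the template overload
    g.Run(0.0f);         // ordinary call: normal overload resolution
}

int main()
{
    caller(Gemm{});
    return 0;
}
```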
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

The element-wise operation kernel arguments lose their `const` qualifier so the kernel can mutate its by-value copies, and `DynamicUnaryOp` operations are initialized on the device before first use:

```diff
@@ -85,9 +85,9 @@ __global__ void
                BsPointer p_bs_grid,
                DsPointer p_ds_grid,
                EDataType* __restrict__ p_e_grid,
-               const AElementwiseOperation a_element_op,
-               const BElementwiseOperation b_element_op,
-               const CDEElementwiseOperation cde_element_op,
+               AElementwiseOperation a_element_op,
+               BElementwiseOperation b_element_op,
+               CDEElementwiseOperation cde_element_op,
                const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
                const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -121,6 +121,19 @@ __global__ void
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

+    if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
+    {
+        a_element_op.InitUnaryOpPtrOnDevice();
+    }
+    if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
+    {
+        b_element_op.InitUnaryOpPtrOnDevice();
+    }
+    if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
+    {
+        cde_element_op.InitUnaryOpPtrOnDevice();
+    }
+
     if constexpr(isMultiA || isMultiB)
     {
         AsPointer p_as_grid_grp;
```
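`DynamicUnaryOp` (introduced later in this commit, in unary_element_wise_operation.hpp) carries a `UnaryOpBase*` that must be created with device-side `new`, since a virtual-dispatch object constructed on the host is not usable from device code. The `if constexpr` guard keeps this setup zero-cost for every ordinary element-wise op. A reduced sketch of the pattern, with hypothetical stand-in types:

```cpp
#include <hip/hip_runtime.h>
#include <type_traits>

struct StaticOp
{
    __device__ void apply(float& y, float x) const { y = x; }
};

struct DynamicOp
{
    // Hypothetical stand-in: rebuilds a device-side object behind a pointer.
    __device__ void InitUnaryOpPtrOnDevice() { /* device-side `new` here */ }
    __device__ void apply(float& y, float x) const { y = x; }
};

template <typename Op>
__global__ void kernel(Op op, float* p)
{
    // Only instantiations where Op is the dynamic type compile this call;
    // for StaticOp the branch disappears entirely.
    if constexpr(std::is_same_v<Op, DynamicOp>)
    {
        op.InitUnaryOpPtrOnDevice();
    }
    op.apply(p[threadIdx.x], p[threadIdx.x]);
}
```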
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

`MultiplyMultiply` gains two specializations for `int` accumulators, and `ScaleAddScaleAddRelu` replaces its `Relu{}` functor calls with an inline ternary:

```diff
@@ -272,6 +272,26 @@ struct MultiplyMultiply
         e = ck::type_convert<ck::bhalf_t>(x0_f);
     }

+    template <>
+    __host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
+        ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
+    {
+        const float x0_f =
+            ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
+
+        e = ck::type_convert<ck::half_t>(x0_f);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
+        ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
+    {
+        const float x0_f =
+            ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
+
+        e = ck::type_convert<ck::bhalf_t>(x0_f);
+    }
 };

 struct MultiplyAddFastGelu
@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu
                                               const float& d1) const
     {
         const float x = c * alpha1_ + alpha2_ * d0 + d1;
-        Relu{}.template operator()<float>(e, x);
+        e = x > 0 ? x : 0;
     }
```

The same replacement is applied in the `half_t` (@@ -396,7 +416,7), `bhalf_t` (@@ -409,7 +429,7), and `int8_t` (@@ -421,7 +441,7) overloads, where `result = x > 0 ? x : 0;` replaces `Relu{}.template operator()<float>(result, x);` before the final `type_convert` to the output type.
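The two new `MultiplyMultiply` overloads accept an `int` accumulator, the natural epilogue input for an int8 GEMM de-quantized by two scale factors. A hedged host-side sketch of the arithmetic (a standalone stand-in, not the CK operator):

```cpp
#include <cstdint>
#include <cstdio>

// e = convert(float(c) * d0 * d1): widen the integer accumulator to float
// before applying both scales, then narrow to the output type.
static float multiply_multiply(int32_t c, float d0, float d1)
{
    return static_cast<float>(c) * d0 * d1;
}

int main()
{
    int32_t acc = 12800; // e.g. an int8-GEMM accumulator
    printf("%f\n", multiply_multiply(acc, 0.01f, 0.5f)); // prints 64.000000
    return 0;
}
```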
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

This file carries most of the commit (+797 −152). It introduces an abstract `UnaryOpBase`, rewrites the existing unary ops as `final` subclasses whose templated `operator()` (with its `static_assert` type whitelist) becomes a set of concrete per-type overrides, and adds a type-erasing `DynamicUnaryOp` at the end of the file.

The new include and base class (@@ -7,11 +7,38):

```diff
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/type_convert.hpp"
+#include <cassert>

 namespace ck {
 namespace tensor_operation {
 namespace element_wise {

+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
+struct UnaryOpBase
+{
+    public:
+    __host__ __device__ ~UnaryOpBase() = default;
+    __host__ __device__ constexpr UnaryOpBase()                    = default;
+    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&)  = default;
+    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&)       = default;
+    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
+    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&)      = default;
+
+    __host__ __device__ virtual inline void operator()(float& y, const float& x) const     = 0;
+    __host__ __device__ virtual inline void operator()(double& y, const double& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
+    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const   = 0;
+    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
+};
```

`PassThrough` (@@ -25,17 +52,30 and @@ -48,36 +88,12) shows the pattern applied to every op: it becomes `struct PassThrough final : public UnaryOpBase` with defaulted special members and six `y = x;` overrides marked `final`. The old same-type specializations (`<float, float>`, `<double, double>`, `<half_t, half_t>`, `<bhalf_t, bhalf_t>`, `<int32_t, int32_t>`, `<int8_t, int8_t>`) are removed, while the cross-type converting specializations (`<float, double>`, `<half_t, float>`, `<bhalf_t, float>`, `<half_t, int8_t>`, and the rest) remain.

The stateless ops follow suit. `UnaryAbs` (@@ -407,20 +417,45) overrides every type with `y = ck::math::abs(x);` and drops the `template <>` header from its `f8_t` overload. `Relu` (@@ -439,20 +474,41) is representative:

```cpp
struct Relu final : public UnaryOpBase
{
    // ... defaulted ctors, assignments and dtor ...

    __host__ __device__ inline void operator()(float& y, const float& x) const final
    {
        y = x > 0 ? x : 0;
    }

    // The double, int32_t, int8_t and half_t overrides are identical;
    // the bhalf_t override round-trips through float:
    __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
    {
        float x_f32 = ck::type_convert<float>(x);
        float y_f32 = x_f32 > 0 ? x_f32 : 0;
        // (conversion back to bhalf_t follows)
    }
};
```

`Sigmoid` (@@ -599,18 +655,52) and `TanH` (@@ -626,18 +716,44) get the same treatment, computing `y = one / (one + ck::math::exp(-x));` and `y = ck::math::tanh(x);` per type. (Note: in the diff, Sigmoid's `bhalf_t` override computes `ck::math::exp(x_f32)` without the minus sign carried by every other override.)

The parameterized ops (@@ -878,138 +994,418) keep their `const float` members, gain `__host__ __device__` constructors and `get_alpha()` / `get_beta()` / `get_gamma()` accessors, and override each type with the same formulas as before:

- `Swish`: `float bx = -beta_ * type_convert<float>(x); y = type_convert<T>(x / (1.f + ck::math::exp(bx)));` (its templated `operator()` is kept, now naming `half_t` without the `ck::` prefix)
- `SoftRelu`: `y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;`
- `Power`: `y = ck::math::pow(casted_alpha + casted_beta * x, casted_gamma);`
- `ClippedRelu`: `y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));`
- `LeakyRelu`: `y = x >= 0 ? x : x * casted_alpha;` (its default `alpha` changes from `0.01f` to `0.f`, and the `bhalf_t` override takes `[[maybe_unused]]` parameters)
- `Elu`: `y = x > 0 ? x : casted_alpha * ck::math::expm1(x);`
- `Logistic`: `y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);`

`ConvScaleRelu` (@@ -1074,7 +1470,7) drops the explicit template call:

```diff
-        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
+        Relu{}(x, c * scale_in_ * scale_wei_);
```

Finally, the new `DynamicUnaryOp` (@@ -1153,6 +1549,255), condensed:

```cpp
struct DynamicUnaryOp
{
    __host__ __device__ DynamicUnaryOp() = delete;

    // One pair of lvalue/rvalue converting constructors per op records the op
    // kind and its scalar parameters, e.g.:
    __host__ __device__ DynamicUnaryOp(const Swish& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }
    // ... likewise for Sigmoid, PassThrough, Logistic, TanH, Relu, SoftRelu,
    //     UnaryAbs, Power, ClippedRelu, LeakyRelu and Elu. Copy construction
    //     and assignment copy unary_op_type_, unary_op_ptr_ and
    //     alpha/beta/gamma; the destructor switches on unary_op_type_ and
    //     deletes unary_op_ptr_ as the concrete type.

    // Instantiates the concrete op behind unary_op_ptr_ with device-side new:
    __device__ void InitUnaryOpPtrOnDevice()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
        // ... remaining ops, forwarding alpha/beta/gamma where needed ...
        default: unary_op_ptr_ = nullptr; break;
        }
    }

    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        unary_op_ptr_->operator()(y, x); // virtual dispatch on the device
    }

    // The __host__ operator() instead switches on unary_op_type_ and invokes
    // a default-constructed op, e.g. Swish{}.operator()(y, x).

    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");
        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    private:
    enum class UnaryOpType
    {
        Swish, Sigmoid, PassThrough, Logistic, TanH, Relu,
        SoftRelu, UnaryAbs, Power, ClippedRelu, LeakyRelu, Elu
    };

    public:
    UnaryOpType unary_op_type_;
    UnaryOpBase* unary_op_ptr_ = nullptr;
    float alpha;
    float beta;
    float gamma;
};
#pragma clang diagnostic pop

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
```
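A hedged sketch of how a caller might use `DynamicUnaryOp`, based only on the constructors, `InitUnaryOpPtrOnDevice`, and `operator()` shown above; the kernel and buffer are hypothetical. As in the grouped-conv kernel earlier in this commit, the init call happens at kernel entry on each thread's by-value copy:

```cpp
#include <hip/hip_runtime.h>

using ck::tensor_operation::element_wise::DynamicUnaryOp;
using ck::tensor_operation::element_wise::LeakyRelu;

__global__ void apply_unary_op(DynamicUnaryOp op, float* p, int n)
{
    op.InitUnaryOpPtrOnDevice(); // build the device-side UnaryOpBase*
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        float y;
        op(y, p[i]); // virtual dispatch through unary_op_ptr_
        p[i] = y;
    }
}

void run_leaky_relu(float* d_p, int n)
{
    DynamicUnaryOp op{LeakyRelu{0.25f}}; // op kind + alpha captured by value
    hipLaunchKernelGGL(
        apply_unary_op, dim3((n + 255) / 256), dim3(256), 0, nullptr, op, d_p, n);
}
```

The design trades virtual dispatch and a device-side allocation for the ability to select the activation at runtime without instantiating a separate kernel per op.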
include/ck/utility/amd_xdlops.hpp

`intrin_mfma_i32_16x16x32i8` switches to the underscore-suffixed builtin name:

```diff
@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
     __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
-                                                 bit_cast<int64_t>(reg_b),
-                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
-                                                 0,
-                                                 0,
-                                                 0);
+            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
+                                                  bit_cast<int64_t>(reg_b),
+                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
+                                                  0,
+                                                  0,
+                                                  0);
     }
 };
```
include/ck/utility/data_type.hpp

A `NumericUtils` specialization for `bhalf_t` is added after the `bf8_t` one:

```diff
@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t>
     static constexpr int bias = 16; // negative zero nan mode
     // static constexpr int bias = 15; // ieee mode
 };
+
+template <>
+struct NumericUtils<bhalf_t>
+{
+    static constexpr int exp  = 8;
+    static constexpr int mant = 7;
+    static constexpr int bias = 128; // negative zero nan mode
+    // static constexpr int bias = 127; // ieee mode
+};
 } // namespace ck
```
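The new entry encodes the bf16 layout: 1 sign bit, 8 exponent bits, 7 mantissa bits. A small hedged compile-time check using a stand-in struct mirroring those fields (not the CK template itself):

```cpp
#include <cstdint>

struct NumericUtilsBF16 // stand-in mirroring ck::NumericUtils<bhalf_t>
{
    static constexpr int exp  = 8; // exponent bits
    static constexpr int mant = 7; // mantissa bits
};

// 1 sign bit + 8 exponent bits + 7 mantissa bits fill the 16-bit storage type.
static_assert(1 + NumericUtilsBF16::exp + NumericUtilsBF16::mant ==
                  8 * sizeof(uint16_t),
              "bf16 is 16 bits");

int main() { return 0; }
```

`bias = 128` matches the "negative zero nan mode" convention the comment names; the IEEE 754 bias for an 8-bit exponent would be 127, kept as the commented alternative.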
include/ck/utility/math_v2.hpp
View file @
4d914af3
...
@@ -653,7 +653,7 @@ inline __device__ double sin<double>(double x)
 template <>
 inline __device__ half_t sin<half_t>(half_t x)
 {
-    return ::hsin(x);
+    return hsin(static_cast<__half>(x));
 };

 template <typename T>
...
@@ -785,7 +785,7 @@ inline __device__ double ceil<double>(double x)
 template <>
 inline __device__ half_t ceil<half_t>(half_t x)
 {
-    return ::hceil(x);
+    return hceil(static_cast<__half>(x));
 };

 template <typename T>
...
@@ -827,7 +827,7 @@ inline __device__ double floor<double>(double x)
 template <>
 inline __device__ half_t floor<half_t>(half_t x)
 {
-    return ::hfloor(x);
+    return hfloor(static_cast<__half>(x));
 };

 template <typename T>
...
@@ -849,7 +849,7 @@ inline __device__ T exp(T x)
 template <>
 inline __device__ half_t exp<half_t>(half_t x)
 {
-    return hexp(x);
+    return hexp(static_cast<__half>(x));
 };

 template <>
...
@@ -873,7 +873,7 @@ inline __device__ T log(T x)
 template <>
 inline __device__ half_t log<half_t>(half_t x)
 {
-    return hlog(x);
+    return hlog(static_cast<__half>(x));
 };

 template <>
...
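All five specializations now convert the operand to `__half` explicitly before calling the fp16 device intrinsics, instead of relying on an implicit conversion behind a `::`-qualified call. A minimal sketch of the resulting call pattern, assuming `half_t` aliases `_Float16` (an assumption made only for this sketch):

#include <hip/hip_fp16.h>

using half_t = _Float16; // CK defines its own half_t; this alias is for the sketch only

// hsin is declared for __half, so the operand is converted explicitly rather
// than depending on an implicit _Float16 -> __half conversion
__device__ half_t sin_half(half_t x) { return hsin(static_cast<__half>(x)); }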
include/ck_tile/core.hpp
View file @
4d914af3
...
@@ -5,6 +5,7 @@
 #include "ck_tile/core/algorithm/cluster_descriptor.hpp"
 #include "ck_tile/core/algorithm/coordinate_transform.hpp"
+#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
 #include "ck_tile/core/algorithm/space_filling_curve.hpp"
 #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
 #include "ck_tile/core/arch/arch.hpp"
...
@@ -24,6 +25,7 @@
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/int8.hpp"
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/math.hpp"
...
@@ -49,13 +51,17 @@
 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/functional_with_tuple.hpp"
 #include "ck_tile/core/utility/ignore.hpp"
+#include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/core/utility/magic_div.hpp"
 #include "ck_tile/core/utility/philox_rand.hpp"
 #include "ck_tile/core/utility/random.hpp"
+#include "ck_tile/core/utility/reduce_operator.hpp"
 #include "ck_tile/core/utility/to_sequence.hpp"
 #include "ck_tile/core/utility/transpose_vectors.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"
...
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
4d914af3
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
     replicate,
     xor_t,
     offset,
+    indexing,
 };

 template <index_t NDimLow, index_t NDimUp>
...
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
     }
 };

+template <typename UpLength, typename IndexingAdaptor>
+struct indexing : public base_transform<1, 1>
+{
+    static constexpr index_t NDimUp = 1;
+
+    using LowerIndex = multi_index<1>;
+    using UpperIndex = multi_index<1>;
+    using UpLengths  = decltype(make_tuple(UpLength{}));
+
+    UpLengths up_lengths_;
+    IndexingAdaptor iadaptor_;
+
+    CK_TILE_HOST_DEVICE constexpr indexing() = default;
+
+    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
+                                           const IndexingAdaptor& iadaptor)
+        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
+    {
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
+    {
+        return coord_transform_enum::indexing;
+    }
+
+    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
+                                                             const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
+                      "wrong! inconsistent # of dimension");
+        iadaptor_.calculate_lower_index(idx_low, idx_up);
+    }
+
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
+                                                const UpIdxDiff& idx_diff_up,
+                                                LowIdx& idx_low,
+                                                const UpIdx& idx_up) const
+    {
+        // TODO: nothing changed here
+        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
+                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
+                      "wrong! inconsistent # of dimension");
+        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr bool
+    is_valid_upper_index_always_mapped_to_valid_lower_index()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    CK_TILE_HOST_DEVICE static constexpr bool
+    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
+    {
+        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
+               IndexingAdaptor::is_known_at_compile_time();
+    }
+
+    CK_TILE_HOST_DEVICE void print() const
+    {
+        printf("embed{");
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+        printf("}");
+    }
+};
+
 //*******************************************************************************************************

 template <typename LowLength>
...
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
 }
 } // namespace ck_tile

+#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
+namespace ck_tile {
+
+template <typename UpLength, typename Indices>
+CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
+                                                           const Indices& indices)
+{
+    // by default we use the simplest one
+    return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
+        up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
+}
+
+template <typename UpLength, typename IndexingAdaptor>
+CK_TILE_HOST_DEVICE constexpr auto
+make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
+{
+    return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
+}
+} // namespace ck_tile
include/ck_tile/core/algorithm/indexing_adaptor.hpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// pre-defined indexing adaptor used for indexing (scatter/gather)
// this version caches the index inside a thread register (which is also preferred in real
// scenarios). However it is the user's responsibility that each thread only provides one index,
// which means moving the coordinate will not change this dim
template <typename IndexingType>
struct indexing_adaptor_onshot_cached
{
    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
        : cached_idx_(idx)
    {
    }

    IndexingType cached_idx_;

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& /*idx_up*/) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
        idx_low(number<0>{}) = cached_idx_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& /*idx_low*/,
                                                const UpIdx& /*idx_up*/) const
    {
        // TODO: nothing changed here
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
        idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
        // pass the diff to lower, but do not change the actual index
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<IndexingType>::value;
    }
};
} // namespace ck_tile
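Stripped of the ck_tile machinery, the adaptor's contract is small: `calculate_lower_index` pins the lower index to the per-thread cached value, and `update_lower_index` forwards the diff without recomputing the mapping. A plain-C++ illustration of that contract (all names here are illustrative stand-ins, not ck_tile types):

#include <cstdio>

struct onshot_cached // stand-in for indexing_adaptor_onshot_cached<int>
{
    int cached_idx_;
    void calculate_lower_index(int& idx_low, int /*idx_up*/) const { idx_low = cached_idx_; }
    void update_lower_index(int& idx_diff_low, int idx_diff_up) const
    {
        idx_diff_low = idx_diff_up; // forward the diff, keep the cached mapping
    }
};

int main()
{
    onshot_cached g{42}; // e.g. this thread gathers row 42
    int low = 0;
    g.calculate_lower_index(low, /*idx_up=*/0);
    std::printf("%d\n", low); // 42
}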
include/ck_tile/core/algorithm/space_filling_curve.hpp
View file @
4d914af3
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -81,8 +81,10 @@ struct space_filling_curve
         return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
     }

+    // Do not use this function directly!
+    // TODO: can refactor into generic lambda in the future
     template <index_t AccessIdx1d>
-    static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
+    static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
     {
 #if 0
     /*
...
@@ -153,11 +155,11 @@ struct space_filling_curve
         return idx_md;
     }

-    // FIXME: rename this function
+    // FIXME: return tuple of number<>, which is compile time only variable
     template <index_t AccessIdx1d>
-    static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+    static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
     {
-        constexpr auto idx = get_index(number<AccessIdx1d>{});
+        constexpr auto idx = _get_index(number<AccessIdx1d>{});

         return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
     }
...
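The renamed public `get_index` now returns a tuple of `number<>` rather than a runtime `Index`, so each coordinate stays a compile-time constant usable as a template argument or array bound. A minimal illustration of why that matters, with a stand-in `number_<>` (the real type lives in ck_tile's integral_constant header):

#include <cstdio>

template <int V>
struct number_ // stand-in for ck_tile's number<>
{
    static constexpr int value = V;
};

int main()
{
    constexpr auto i = number_<2>{};        // a coordinate as an integral-constant type
    float buf[number_<4>::value + i.value]; // usable where a constant expression is required
    std::printf("%zu\n", sizeof(buf) / sizeof(buf[0])); // 6
}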
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
4d914af3
...
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }

+namespace impl {
+// below type indicate the data type used for buffer load inline asm
+// clang-format off
+template <index_t N, typename T> struct smem_load_trait;
+template <typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
+template <typename T> struct smem_load_trait<8, T>  { using payload_t = fp32x2_t; };
+template <typename T> struct smem_load_trait<4, T>  { using payload_t = float; };
+template <typename T> struct smem_load_trait<2, T>  { using payload_t = float; };
+template <typename T> struct smem_load_trait<1, T>  { using payload_t = float; };
+// clang-format on
+} // namespace impl
+
+// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
+template <index_t>
+struct smem_load;
+
+template <>
+struct smem_load<16>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 16);
+        using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
+        asm volatile("ds_read_b128 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+
+template <>
+struct smem_load<8>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 8);
+        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
+        asm volatile("ds_read_b64 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+
+template <>
+struct smem_load<4>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
+        asm volatile("ds_read_b32 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+
+template <>
+struct smem_load<2>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4);
+        // subdword is buggy, use dword buf and convert manually
+        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
+        asm volatile("ds_read_u16 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+
+template <>
+struct smem_load<1>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
+        asm volatile("ds_read_u8 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+
 // clang-format off
 namespace impl {
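A minimal sketch of driving `smem_load<16>`: `v_offset` carries the LDS byte address in a VGPR, and `i_offset` must fold to an inline-asm immediate, so this relies on inlining under optimization. It also assumes `tile` is the kernel's only LDS allocation (and therefore starts at LDS offset 0); the kernel name and layout are illustrative:

#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"

__global__ void lds_read_demo(float* out)
{
    __shared__ float tile[256];
    tile[threadIdx.x] = static_cast<float>(threadIdx.x); // assumes 256 threads per block
    __syncthreads();

    if(threadIdx.x < 64) // 64 lanes * 16 bytes covers the 256-float tile
    {
        ck_tile::fp32x4_t v;
        ck_tile::index_t v_offset = threadIdx.x * sizeof(ck_tile::fp32x4_t);
        ck_tile::smem_load<16>{}(v, v_offset, /*i_offset=*/0); // one ds_read_b128
        out[threadIdx.x] = v[0];                               // lane i sees tile[4*i]
    }
}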
...
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                        int soffset, // dst_wave_addr_offset
                                        int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

+// Direct loads from global to LDS.
+CK_TILE_DEVICE_EXTERN void
+llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
+                                __attribute__((address_space(3))) uint32_t* lds_ptr,
+                                index_t size,
+                                index_t voffset,
+                                index_t soffset,
+                                index_t offset,
+                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
+
 template <bool pre_nop = false>
 CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
                                               int32x4_t rsrc,
...
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                              int32x4_t src_wave_buffer_resource,
                                              index_t src_thread_addr_offset,
                                              index_t src_wave_addr_offset,
+                                             index_t src_linear_addr_offset,
                                              index_t flag = 0,
                                              bool_constant<pre_nop> = {})
 {
...
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                   src_wave_buffer_resource,
                   src_thread_addr_offset,
                   src_wave_addr_offset,
-                  0,
+                  src_linear_addr_offset,
                   flag,
                   bool_constant<pre_nop>{});
     }
...
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                   src_wave_buffer_resource,
                   src_thread_addr_offset,
                   src_wave_addr_offset,
-                  0,
+                  src_linear_addr_offset,
                   flag,
                   bool_constant<pre_nop>{});
     }
...
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                bool_constant<pre_nop>{});
 }

+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = true>
+CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
+                                          int32x4_t src_wave_buffer_resource,
+                                          index_t src_thread_addr_offset,
+                                          index_t src_wave_addr_offset,
+                                          index_t src_immediate_addr_offset = 0,
+                                          index_t flag                      = 0,
+                                          bool_constant<oob_conditional_check> = {})
+{
+    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
+    if constexpr(oob_conditional_check)
+    {
+        index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        v_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast<index_t>(coherence));
+    }
+    else
+    {
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        src_thread_addr_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast<index_t>(coherence));
+    }
+}
+
 template <index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
 CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
...
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                               int32x4_t dst_wave_buffer_resource,
                                               index_t dst_thread_addr_offset,
                                               index_t dst_wave_addr_offset,
+                                              index_t dst_linear_addr_offset,
                                               index_t is_valid_element = 1)
 {
     constexpr index_t bytes = sizeof(T) * N;
...
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset,
-                   0,
+                   dst_linear_addr_offset,
                    is_valid_element);
     }
     else
...
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                    dst_wave_buffer_resource,
                    dst_thread_addr_offset,
                    dst_wave_addr_offset,
-                   0);
+                   dst_linear_addr_offset);
     }
 }
...
@@ -2014,6 +2156,7 @@ template <typename T,
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const T* p_src_wave,
                                         index_t src_thread_element_offset,
+                                        index_t src_linear_element_offset,
                                         index_t src_element_space_size,
                                         index_t is_valid_element = 0,
                                         bool_constant<pre_nop> = {})
...
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

     amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant<pre_nop>{});
 }
...
@@ -2041,16 +2186,19 @@ template <typename T,
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const int32x4_t src_wave_buffer_resource,
                                         index_t src_thread_element_offset,
+                                        index_t src_linear_element_offset,
                                         index_t is_valid_element = 0,
                                         bool_constant<pre_nop> = {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

     amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant<pre_nop>{});
 }
...
@@ -2066,6 +2214,7 @@ template <typename T,
 CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const T* p_src_wave,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        index_t src_element_space_size,
                                                        bool_constant<pre_nop> = {})
 {
...
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+    amd_async_buffer_load_impl<T, N, coherence>(
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, src_linear_addr_offset,
+        bool_constant<pre_nop>{});
 }

 // This version support buffer resource as input arg
...
@@ -2086,12 +2240,42 @@ template <typename T,
 CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const int32x4_t src_wave_buffer_resource,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        bool_constant<pre_nop> = {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+    amd_async_buffer_load_impl<T, N, coherence>(
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, src_linear_addr_offset,
+        bool_constant<pre_nop>{});
+}
+
+// This version support buffer resource as input arg
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
+                                                   const int32x4_t src_wave_buffer_resource,
+                                                   index_t src_thread_element_offset,
+                                                   index_t src_linear_element_offset,
+                                                   bool is_valid_element,
+                                                   bool_constant<oob_conditional_check> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
+
+    amd_async_buffer_load<T, N, coherence>(smem,
+                                           src_wave_buffer_resource,
+                                           src_thread_addr_offset,
+                                           0,
+                                           src_linear_addr_offset,
+                                           is_valid_element,
+                                           bool_constant<oob_conditional_check>{});
 }

 // buffer_store requires:
...
@@ -2146,6 +2330,7 @@ template <typename T,
 CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                          T* p_dst_wave,
                                          const index_t dst_thread_element_offset,
+                                         const index_t dst_linear_element_offset,
                                          const bool dst_thread_element_valid,
                                          const index_t dst_element_space_size)
 {
...
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
         make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

     index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);

     amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
                                                                       dst_wave_buffer_resource,
                                                                       dst_thread_addr_offset,
                                                                       0,
+                                                                      dst_linear_addr_offset,
                                                                       dst_thread_element_valid);
 }
...
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
 #endif
 }

-// Direct loads from global to LDS.
-CK_TILE_DEVICE_EXTERN void
-llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
-                                __attribute__((address_space(3))) uint32_t* lds_ptr,
-                                index_t size,
-                                index_t voffset,
-                                index_t soffset,
-                                index_t offset,
-                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
-
 template <typename T, index_t NumElemsPerThread>
 CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                                   const index_t global_offset,
...
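Every wrapper in this file's changes follows the same convention: callers pass element offsets, and the wrapper scales them to byte offsets exactly once before handing them to the raw implementation, with the new linear (statically-known) offset following the same rule. A standalone illustration of that conversion (values are arbitrary):

#include <cstdint>
#include <cstdio>

int main()
{
    using index_t = int32_t;
    index_t src_thread_element_offset = 10; // per-thread, runtime part of the address
    index_t src_linear_element_offset = 4;  // statically-known, linear part

    // the wrappers scale both to bytes before calling the *_impl functions
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(float); // 40
    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(float); // 16

    std::printf("%d %d\n", src_thread_addr_offset, src_linear_addr_offset);
}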
include/ck_tile/core/arch/utility.hpp
View file @
4d914af3
...
@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
 #endif
 }

+template <typename T>
+CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
+{
+#if 0
+    return __shfl(v_local, src_lane);
+#elif 1
+    if constexpr(sizeof(int32_t) > sizeof(T))
+    {
+        union packet
+        {
+            int32_t x;
+            T v;
+        };
+        packet p;
+        p.v = v_local;
+        packet p_remote;
+        p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
+
+        return p_remote.v;
+    }
+    else if constexpr(sizeof(int32_t) == sizeof(T))
+    {
+        const int32_t v_remote_tmp =
+            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
+
+        return bit_cast<T>(v_remote_tmp);
+    }
+    else
+    {
+        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
+        constexpr index_t elm = sizeof(T) / sizeof(int32_t);
+        using vector_type     = thread_buffer<int32_t, elm>;
+        auto vs        = bit_cast<vector_type>(v_local);
+        auto vs_remote = vector_type{};
+        static_for<0, elm, 1>{}([&](auto i_e) {
+            int32_t tmp =
+                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
+            vs_remote(i_e) = tmp;
+        });
+        return bit_cast<T>(vs_remote);
+    }
+#endif
+}
 } // namespace ck_tile
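A short sketch of using the new `warp_shuffle` to broadcast lane 0's value across the wave; `ds_bpermute` addresses lanes by byte, which is why the implementation shifts `src_lane` left by 2 (the kernel name below is illustrative):

#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"

__global__ void broadcast_demo(float* out)
{
    float v     = static_cast<float>(threadIdx.x);
    float lane0 = ck_tile::warp_shuffle(v, 0u); // every lane fetches lane 0's v
    out[threadIdx.x] = lane0;                   // all written elements become 0.0f
}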
include/ck_tile/core/config.hpp
View file @
4d914af3
...
@@ -32,13 +32,28 @@
 #define CK_TILE_DEVICE inline __device__
 #define CK_TILE_HOST_DEVICE inline __host__ __device__
 #define CK_TILE_DEVICE_EXTERN __device__
+#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
 #else
 #define CK_TILE_HOST inline
 #define CK_TILE_DEVICE inline
 #define CK_TILE_HOST_DEVICE inline
 #define CK_TILE_DEVICE_EXTERN
+#define CK_TILE_HOST_DEVICE_EXTERN
 #endif

+// implementing the "memory address space" attribute
+// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
+#ifdef __HIPCC_
+#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
+#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
+#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
+#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
+#else
+#define CK_TILE_GENERIC_ADDR
+#define CK_TILE_GLOBAL_ADDR
+#define CK_TILE_LDS_ADDR
+#define CK_TILE_BUF_RES_ADDR
+#endif
+
 #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
 #define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
 #endif
...
@@ -203,3 +218,8 @@
 #ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
 #define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
 #endif
+
+// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
+#ifndef CK_TILE_WORKAROUND_SWDEV_383542
+#define CK_TILE_WORKAROUND_SWDEV_383542 1
+#endif
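The new address-space macros annotate pointers with the LLVM AMDGPU address spaces linked above, letting the backend pick the matching instructions (e.g. `ds_*` for LDS); on non-HIP builds they expand to nothing, so the same signature still compiles on the host. A minimal sketch (the function name is illustrative):

#include <cstdint>
#include "ck_tile/core/config.hpp"

// p is known to live in LDS (address_space(3)), so the compiler can select
// ds_write instructions for the store
CK_TILE_DEVICE void zero_lds_word(CK_TILE_LDS_ADDR uint32_t* p) { *p = 0u; }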
include/ck_tile/core/container/sequence.hpp
View file @
4d914af3
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
                        typename arithmetic_sequence_gen<0, N, 1>::type{});
 }

+namespace impl {
+
+template <typename, typename, typename, index_t>
+struct reverse_slice_sequence_impl;
+
+template <index_t x,
+          index_t... xs,
+          index_t m,
+          index_t... ms,
+          index_t id,
+          index_t... ids,
+          index_t SliceSize>
+struct reverse_slice_sequence_impl<sequence<x, xs...>,
+                                   sequence<m, ms...>,
+                                   sequence<id, ids...>,
+                                   SliceSize>
+{
+    using old_scan =
+        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
+
+    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
+    static constexpr auto slice_length =
+        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
+
+    using dim_lengths =
+        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
+    using dim_slices =
+        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
+    using remaining_slice_sizes = typename sequence_merge<
+        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
+        typename old_scan::remaining_slice_sizes>::type;
+
+    // the first idx that sliced length not equal to original length
+    static constexpr index_t _flag =
+        slice_length != x && remaining_slice_sizes{}.front().value == 1;
+    static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
+    static constexpr index_t _split_idx =
+        std::conditional_t<_split_flag, number<id>, number<0>>::value;
+
+    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
+    static constexpr index_t split_idx  = std::
+        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
+};
+
+template <index_t x, index_t m, index_t id, index_t SliceSize>
+struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
+{
+    static constexpr auto slice_size = SliceSize;
+    static constexpr auto slice_length =
+        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
+
+    using dim_lengths = sequence<slice_length>;
+    using dim_slices  = sequence<x / slice_length>;
+    using remaining_slice_sizes =
+        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
+
+    // the first idx that sliced length not equal to original length
+    static constexpr index_t _flag =
+        slice_length != x && remaining_slice_sizes{}.front().value == 1;
+    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
+    static constexpr index_t split_idx =
+        std::conditional_t<split_flag, number<id>, number<0>>::value;
+};
+} // namespace impl
+
+// clang-format off
+// input a sequence (with optional mask), and the SliceSize : size per slice
+// output the sequence each slice, and number of slices
+//
+// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
+//      <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
+//      <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
+//      <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
+//
+//      <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices  , slice_idx: 0
+//      <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices  , slice_idx: 0
+//      <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices  , slice_idx: 0
+//      <4, 2, 8>, 8  -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices  , slice_idx: 1
+//      <4, 2, 8>, 4  -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
+//      <4, 2, 8>, 2  -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
+//      <4, 2, 8>, 1  -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
+//
+//      <4, 2, 1, 4, 2> / 4 ->
+//      mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
+//
+// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
+// have split slices (right -> left)
+// or the first index that sliced length is different from the original length
+// clang-format on
+template <typename Seq,
+          index_t SliceSize,
+          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
+constexpr auto reverse_slice_sequence(Seq,
+                                      number<SliceSize>,
+                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
+{
+    static_assert(Seq::size() == Mask::size());
+    using sliced_type =
+        impl::reverse_slice_sequence_impl<Seq,
+                                          Mask,
+                                          typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
+                                          SliceSize>;
+    static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
+                  "can not evenly divide this sequence, please check");
+    return make_tuple(typename sliced_type::dim_lengths{},
+                      typename sliced_type::dim_slices{},
+                      number<sliced_type::split_idx>{});
+}
+
+template <typename Seq,
+          index_t SliceSize,
+          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
+constexpr auto slice_sequence(Seq,
+                              number<SliceSize>,
+                              Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
+{
+    constexpr auto r =
+        reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
+    return make_tuple(r[number<0>{}].reverse(),
+                      r[number<1>{}].reverse(),
+                      number<Seq::size() - r[number<2>{}] - 1>{});
+}
 } // namespace ck_tile
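A compile-time check of the documented example `<4, 2, 8>` sliced by 8 (this sketch relies only on the behavior listed in the comment block above: lengths `<1, 1, 8>` and slice counts `<4, 2, 1>`):

#include <type_traits>
#include "ck_tile/core.hpp"

using namespace ck_tile;

void slice_sequence_example()
{
    constexpr auto r = slice_sequence(sequence<4, 2, 8>{}, number<8>{});
    static_assert(
        std::is_same_v<remove_cvref_t<decltype(r[number<0>{}])>, sequence<1, 1, 8>>); // lengths
    static_assert(
        std::is_same_v<remove_cvref_t<decltype(r[number<1>{}])>, sequence<4, 2, 1>>); // slices
}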
include/ck_tile/core/container/tuple.hpp
View file @
4d914af3
...
@@ -488,6 +508,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
 }

+namespace detail {
+template <typename F, typename X, index_t... Is>
+CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
+{
+    return concat_tuple(f(x.at(number<Is>{}))...);
+}
+} // namespace detail
+
+// make sure F return at least a tuple
+// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
+// this function will return
+template <typename F, typename X>
+CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
+{
+    return detail::embed_tuples_impl(
+        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
+}
+
 // By default unroll to the flatten
 template <index_t Depth = 0, index_t MaxDepth = -1>
 CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
...
@@ -603,7 +623,7 @@ template <typename... Ys,
                            false>
 CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
 {
-    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
+    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
     constexpr index_t NSize = sizeof...(Ys);

     static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
     return y;
...
@@ -615,7 +635,7 @@ template <typename... Ys,
                            false>
 CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
 {
-    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
+    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
     constexpr index_t NSize = sizeof...(Ys);

     static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
     return y;
...
@@ -627,7 +647,7 @@ template <typename... Xs,
                            false>
 CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
 {
-    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
     constexpr index_t NSize = sizeof...(Xs);

     tuple<Xs...> r;
...
@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
     return r;
 }

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+
+    return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
+}
+
 template <typename... Xs,
           typename Y,
           std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
               false>
 CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
 {
-    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
     constexpr index_t NSize = sizeof...(Xs);

     tuple<Xs...> r;
...
@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
     return r;
 }

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+
+    return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
+}
+
 template <typename... Xs,
           typename Y,
           std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
               false>
 CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
 {
-    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
+    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
     constexpr index_t NSize = sizeof...(Xs);

     tuple<Xs...> r;
...
@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
     return a * x;
 }

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+
+    return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
+}
+
 template <typename... Xs, typename... Ys>
 CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
 {
...
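A short usage sketch of the new elementwise tuple-tuple operators added above; everything evaluates at compile time:

#include "ck_tile/core.hpp"

using namespace ck_tile;

void tuple_ops_example()
{
    constexpr auto x = make_tuple(1, 2, 3);
    constexpr auto y = make_tuple(10, 20, 30);
    constexpr auto s = x + y; // elementwise: tuple{11, 22, 33}
    static_assert(s[number<0>{}] == 11 && s[number<2>{}] == 33);
}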
include/ck_tile/core/numeric/int8.hpp
0 → 100644
View file @
4d914af3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>

#pragma once

namespace ck_tile {

// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access the original int8_t
using int8_t = int8_t;

// limits
template <class T>
struct numeric;

template <>
struct numeric<int8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon() { return 1; /* not used */ }

    CK_TILE_HOST_DEVICE static constexpr int8_t round_error() { return 1; /* not used */ }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr int8_t infinity() { return 1; /* not used */ }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN() { return 1; /* not used */ }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN() { return 1; /* not used */ }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min() { return 1; /* not used */ }

    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};

#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
    static constexpr int exp = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;
    static constexpr uint16_t nan_mask = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask = 0x1F;
    static constexpr uint32_t Inf = 0x7C00;
    static constexpr uint32_t NegInf = 0xFC00;
    static constexpr uint32_t NaN = 0x7C01;
    static constexpr uint32_t Neg0 = 0x8000;
    using bitwise_type = uint16_t;
};
#endif

CK_TILE_HOST_DEVICE constexpr float int8_to_float(const int8_t& x)
{
    return static_cast<float>(x);
}

CK_TILE_HOST_DEVICE constexpr int8_t float_to_int8(const float& x)
{
    return static_cast<int8_t>(x);
}

} // namespace ck_tile
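A small usage sketch of the limits and conversion helpers above, clamping a float accumulator into the int8 range (the function name is illustrative):

#include "ck_tile/core/numeric/int8.hpp"

inline int8_t saturate_to_int8(float x)
{
    float lo = ck_tile::int8_to_float(ck_tile::numeric<int8_t>::lowest()); // -128.f
    float hi = ck_tile::int8_to_float(ck_tile::numeric<int8_t>::max());    //  127.f
    return ck_tile::float_to_int8(x < lo ? lo : (x > hi ? hi : x));
}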
include/ck_tile/core/numeric/math.hpp
View file @
4d914af3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -487,55 +487,12 @@ struct log2e<float>
...
@@ -487,55 +487,12 @@ struct log2e<float>
template
<
typename
T
=
double
>
template
<
typename
T
=
double
>
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
// math
CK_TILE_HOST_DEVICE
float
abs
(
const
float
&
x
)
{
union
{
float
f32
;
uint32_t
u32
;
}
y
;
y
.
f32
=
x
;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
return
y
.
f32
;
}
CK_TILE_HOST_DEVICE
bool
isnan
(
const
float
&
x
)
{
uint32_t
xx
=
bit_cast
<
uint32_t
>
(
x
);
return
(
xx
&
0x7fffffff
)
>
0x7F800000
;
}
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_DEVICE
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
CK_TILE_DEVICE
float
exp
(
float
x
)
{
return
__ocml_exp_f32
(
x
);
};
CK_TILE_HOST
float
exp
(
float
x
)
{
return
std
::
expf
(
x
);
}
CK_TILE_DEVICE
CK_TILE_DEVICE
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
CK_TILE_HOST
CK_TILE_HOST
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
CK_TILE_DEVICE
float
log
(
float
x
)
{
return
__logf
(
x
);
};
CK_TILE_HOST
float
log
(
float
x
)
{
return
std
::
logf
(
x
);
};
CK_TILE_DEVICE
uint16_t
sad_u16
(
uint16_t
x
,
uint16_t
y
,
uint16_t
acc
)
CK_TILE_DEVICE
uint16_t
sad_u16
(
uint16_t
x
,
uint16_t
y
,
uint16_t
acc
)
{
{
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
...
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
...
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
}
///////////////////////////////////////////////////////////////
}
// namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace
ck_tile
{
#if CK_TILE_WORKAROUND_SWDEV_383542
extern
"C"
CK_TILE_DEVICE
float
__ocml_native_recip_f32
(
float
);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
}
#endif
CK_TILE_HOST
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_HOST
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
std
::
sqrt
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acos
<
float
>
(
float
x
)
{
return
std
::
acosf
(
x
);
};
template
<
>
CK_TILE_HOST double acos<double>(double x) { return std::acos(x); };
template <typename T>
CK_TILE_HOST T neg(T x) { return type_convert<T>(-(type_convert<float>(x))); };
template <>
CK_TILE_HOST float neg<float>(float x) { return -x; };
template <>
CK_TILE_HOST double neg<double>(double x) { return -x; };
template <>
CK_TILE_HOST int32_t neg<int32_t>(int32_t x) { return -x; };
template <>
CK_TILE_HOST int8_t neg<int8_t>(int8_t x) { return -x; };
template <typename T>
CK_TILE_HOST T atan(T x) { return type_convert<T>(std::atanf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float atan<float>(float x) { return std::atanf(x); };
template <>
CK_TILE_HOST double atan<double>(double x) { return std::atan(x); };
template <typename T>
CK_TILE_HOST T sin(T x) { return type_convert<T>(std::sinf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float sin<float>(float x) { return std::sinf(x); };
template <>
CK_TILE_HOST double sin<double>(double x) { return std::sin(x); };
template <typename T>
CK_TILE_HOST T asin(T x) { return type_convert<T>(std::asinf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float asin<float>(float x) { return std::asinf(x); };
template <>
CK_TILE_HOST double asin<double>(double x) { return std::asin(x); };
template <typename T>
CK_TILE_HOST T asinh(T x) { return type_convert<T>(std::asinhf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float asinh<float>(float x) { return std::asinhf(x); };
template <>
CK_TILE_HOST double asinh<double>(double x) { return std::asinh(x); };
template <typename T>
CK_TILE_HOST T cos(T x) { return type_convert<T>(std::cosf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float cos<float>(float x) { return std::cosf(x); };
template <>
CK_TILE_HOST double cos<double>(double x) { return std::cos(x); };
template <typename T>
CK_TILE_HOST T acosh(T x) { return type_convert<T>(std::acoshf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float acosh<float>(float x) { return std::acoshf(x); };
template <>
CK_TILE_HOST double acosh<double>(double x) { return std::acosh(x); };
template <typename T>
CK_TILE_HOST T tan(T x) { return type_convert<T>(std::tanf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float tan<float>(float x) { return std::tanf(x); };
template <>
CK_TILE_HOST double tan<double>(double x) { return std::tan(x); };
template <typename T>
CK_TILE_HOST T atanh(T x) { return type_convert<T>(std::atanhf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float atanh<float>(float x) { return std::atanhf(x); };
template <>
CK_TILE_HOST double atanh<double>(double x) { return std::atanh(x); };
template <typename T>
CK_TILE_HOST T sinh(T x) { return type_convert<T>(std::sinhf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float sinh<float>(float x) { return std::sinhf(x); };
template <>
CK_TILE_HOST double sinh<double>(double x) { return std::sinh(x); };
template <typename T>
CK_TILE_HOST T ceil(T x) { return type_convert<T>(std::ceilf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float ceil<float>(float x) { return std::ceilf(x); };
template <>
CK_TILE_HOST double ceil<double>(double x) { return std::ceil(x); };
template <typename T>
CK_TILE_HOST T cosh(T x) { return type_convert<T>(std::coshf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float cosh<float>(float x) { return std::coshf(x); };
template <>
CK_TILE_HOST double cosh<double>(double x) { return std::cosh(x); };
template <typename T>
CK_TILE_HOST T floor(T x) { return type_convert<T>(std::floorf(type_convert<float>(x))); };
template <>
CK_TILE_HOST float floor<float>(float x) { return std::floorf(x); };
template <>
CK_TILE_HOST double floor<double>(double x) { return std::floor(x); };
template <typename T>
CK_TILE_HOST T rcp(T x) { return type_convert<T>(1.f / type_convert<float>(x)); };
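// Note: the host rcp above always divides in float, so rcp<double> is only accurate to
// single precision; callers that need a full-precision double reciprocal should compute
// 1.0 / x directly.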
template <typename T>
CK_TILE_HOST T exp(T x) { return type_convert<T>(std::expf(type_convert<float>(x))); }
template <>
CK_TILE_HOST float exp<float>(float x) { return std::expf(x); }
template <>
CK_TILE_HOST double exp<double>(double x) { return std::exp(x); }
template <typename T>
CK_TILE_HOST T log(T x) { return type_convert<T>(std::logf(type_convert<float>(x))); }
template <>
CK_TILE_HOST float log<float>(float x) { return std::logf(x); }
template <>
CK_TILE_HOST double log<double>(double x) { return std::log(x); }
template <typename T>
CK_TILE_HOST T pow(T x, T gamma)
{
    return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
}
template <>
CK_TILE_HOST float pow<float>(float x, float gamma) { return std::powf(x, gamma); }
template <>
CK_TILE_HOST double pow<double>(double x, double gamma) { return std::pow(x, gamma); }
template <typename T>
CK_TILE_HOST T expm1(T x) { return type_convert<T>(std::expm1f(type_convert<float>(x))); }
template <>
CK_TILE_HOST float expm1<float>(float x) { return std::expm1f(x); }
template <>
CK_TILE_HOST double expm1<double>(double x) { return std::expm1(x); }
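// Illustrative usage (not part of the header): the generic host overloads round-trip
// through float via type_convert, so reduced-precision types get std math "for free".
// A minimal host-side sketch, assuming fp16_t converts to/from float with type_convert:
//
//   fp16_t h = type_convert<fp16_t>(0.5f);
//   fp16_t e = ck_tile::exp(h);            // computed as std::expf(0.5f), then narrowed
//   float  r = type_convert<float>(e);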
// math functions for the HIP kernel; some are implemented by calling HIP builtin functions
CK_TILE_DEVICE float abs(float x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;
    y.u32 = y.u32 & 0x7fffffff;
    return y.f32;
};
CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
CK_TILE_DEVICE int8_t abs(int8_t x)
{
    int8_t sgn = x >> (8 - 1);
    return (x ^ sgn) - sgn;
};
CK_TILE_DEVICE int32_t abs(int32_t x)
{
    int32_t sgn = x >> (32 - 1);
    return (x ^ sgn) - sgn;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE int4_t abs(int4_t x)
{
    int4_t sgn = x >> (4 - 1);
    return (x ^ sgn) - sgn;
};
#endif
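// The integer abs above is the branchless two's-complement trick: the arithmetic shift
// makes sgn all-ones (-1) for negative x and 0 otherwise, so (x ^ sgn) - sgn yields
// ~x + 1 == -x when x < 0 and leaves x unchanged when x >= 0. Worked example with
// int8_t x = -5: sgn = 0xff, x ^ sgn = 4, 4 - (-1) = 5. Like ::abs, it still overflows
// on the most negative value, e.g. abs(int8_t(-128)).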
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    uint16_t abs_xx = xx & 0x7fff;
    fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
    return abs_x;
};
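// Both floating-point abs variants simply clear the sign bit: bit 31 (mask 0x7fffffff)
// for float above, bit 15 (mask 0x7fff) for fp16_t here. This is exact for every input,
// including -0.0, infinities, and NaNs.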
CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(int8_t x)
{
    (void)x;
    return false;
};
CK_TILE_DEVICE bool isnan(int32_t x)
{
    (void)x;
    return false;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
    (void)x;
    return false;
};
#endif
CK_TILE_DEVICE bool isnan(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    return (xx & 0x7FFF) > 0x7C00;
};
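// The fp16 check works on the bit pattern: after masking off the sign bit, values
// strictly above 0x7C00 (all-ones exponent with a nonzero mantissa) are NaN, while
// 0x7C00 itself is infinity. E.g. 0x7E00 (the canonical quiet NaN) -> true,
// 0x7C00 (+inf) -> false.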
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
    return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
template <typename T>
CK_TILE_DEVICE T tanh(T x) { return type_convert<T>(::tanhf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float tanh<float>(float x) { return ::tanhf(x); };
template <>
CK_TILE_DEVICE double tanh<double>(double x) { return ::tanh(x); };
template <typename T>
CK_TILE_DEVICE T acos(T x) { return type_convert<T>(::acosf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float acos<float>(float x) { return ::acosf(x); };
template <>
CK_TILE_DEVICE double acos<double>(double x) { return ::acos(x); };
template <typename T>
CK_TILE_DEVICE T neg(T x) { return type_convert<T>(-(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float neg<float>(float x) { return -x; };
template <>
CK_TILE_DEVICE double neg<double>(double x) { return -x; };
template <>
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x) { return -x; };
template <>
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x) { return -x; };
template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x) { return -x; };
template <typename T>
CK_TILE_DEVICE T atan(T x) { return type_convert<T>(::atanf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float atan<float>(float x) { return ::atanf(x); };
template <>
CK_TILE_DEVICE double atan<double>(double x) { return ::atan(x); };
template <typename T>
CK_TILE_DEVICE T sin(T x) { return type_convert<T>(::sinf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float sin<float>(float x) { return ::sinf(x); };
template <>
CK_TILE_DEVICE double sin<double>(double x) { return ::sin(x); };
template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x) { return __ocml_sin_f16(x); };
template <typename T>
CK_TILE_DEVICE T asin(T x) { return type_convert<T>(::asinf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float asin<float>(float x) { return ::asinf(x); };
template <>
CK_TILE_DEVICE double asin<double>(double x) { return ::asin(x); };
template <typename T>
CK_TILE_DEVICE T asinh(T x) { return type_convert<T>(::asinhf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float asinh<float>(float x) { return ::asinhf(x); };
template <>
CK_TILE_DEVICE double asinh<double>(double x) { return ::asinh(x); };
template <typename T>
CK_TILE_DEVICE T acosh(T x) { return type_convert<T>(::acoshf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float acosh<float>(float x) { return ::acoshf(x); };
template <>
CK_TILE_DEVICE double acosh<double>(double x) { return ::acosh(x); };
template <typename T>
CK_TILE_DEVICE T tan(T x) { return type_convert<T>(::tanf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float tan<float>(float x) { return ::tanf(x); };
template <>
CK_TILE_DEVICE double tan<double>(double x) { return ::tan(x); };
template <typename T>
CK_TILE_DEVICE T atanh(T x) { return type_convert<T>(::atanhf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float atanh<float>(float x) { return ::atanhf(x); };
template <>
CK_TILE_DEVICE double atanh<double>(double x) { return ::atanh(x); };
template <typename T>
CK_TILE_DEVICE T sinh(T x) { return type_convert<T>(::sinhf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float sinh<float>(float x) { return ::sinhf(x); };
template <>
CK_TILE_DEVICE double sinh<double>(double x) { return ::sinh(x); };
template <typename T>
CK_TILE_DEVICE T ceil(T x) { return type_convert<T>(::ceilf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float ceil<float>(float x) { return ::ceilf(x); };
template <>
CK_TILE_DEVICE double ceil<double>(double x) { return ::ceil(x); };
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x) { return __ocml_ceil_f16(x); };
template <typename T>
CK_TILE_DEVICE T cosh(T x) { return type_convert<T>(::coshf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float cosh<float>(float x) { return ::coshf(x); };
template <>
CK_TILE_DEVICE double cosh<double>(double x) { return ::cosh(x); };
template <typename T>
CK_TILE_DEVICE T floor(T x) { return type_convert<T>(::floorf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float floor<float>(float x) { return ::floorf(x); };
template <>
CK_TILE_DEVICE double floor<double>(double x) { return ::floor(x); };
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x) { return __ocml_floor_f16(x); };
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
    return __frcp_rn(x);
#else
    // return __ocml_native_recip_f32(x);
    return __builtin_amdgcn_rcpf(x);
#endif
};
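// Note on the rcp workaround: when CK_TILE_WORKAROUND_SWDEV_383542 is set, rcp maps to
// the hardware v_rcp_f32 via __builtin_amdgcn_rcpf, a fast approximate reciprocal,
// rather than the round-to-nearest __frcp_rn; the trade is a small precision loss for
// lower latency.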
template <typename T>
CK_TILE_DEVICE T exp(T x) { return type_convert<T>(__ocml_exp_f32(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x) { return __ocml_exp_f16(x); };
template <>
CK_TILE_DEVICE float exp<float>(float x) { return __ocml_exp_f32(x); };
template <>
CK_TILE_DEVICE double exp<double>(double x) { return ::exp(x); };
template <typename T>
CK_TILE_DEVICE T log(T x) { return type_convert<T>(__logf(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x) { return __ocml_log_f16(x); };
template <>
CK_TILE_DEVICE float log<float>(float x) { return __logf(x); };
template <>
CK_TILE_DEVICE double log<double>(double x) { return ::log(x); };
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
    return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
};
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma) { return powf(x, gamma); };
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma) { return ::pow(x, gamma); };
template <typename T>
CK_TILE_DEVICE T expm1(T x) { return type_convert<T>(expm1f(type_convert<float>(x))); };
template <>
CK_TILE_DEVICE float expm1<float>(float x) { return expm1f(x); };
template <>
CK_TILE_DEVICE double expm1<double>(double x) { return ::expm1(x); };
} // namespace ck_tile
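// Illustrative kernel usage (a sketch, not part of this header): the device overloads
// resolve by argument type, so a HIP kernel can call them directly. The kernel name and
// launch parameters below are hypothetical.
//
//   __global__ void apply_exp(const float* in, float* out, int n)
//   {
//       int i = blockIdx.x * blockDim.x + threadIdx.x;
//       if(i < n)
//           out[i] = ck_tile::exp(in[i]); // dispatches to __ocml_exp_f32
//   }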