Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7971bb5b
Commit
7971bb5b
authored
Aug 04, 2024
by
carlushuang
Browse files
add test for scatter/gather
parent
d311c953
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2688 additions
and
44 deletions
+2688
-44
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+200
-0
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+60
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+9
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+951
-28
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+33
-16
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+1
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+7
-0
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+1151
-0
test/CMakeLists.txt
test/CMakeLists.txt
+2
-0
test/scatter_gather/CMakeLists.txt
test/scatter_gather/CMakeLists.txt
+2
-0
test/scatter_gather/scatter_gather.cpp
test/scatter_gather/scatter_gather.cpp
+271
-0
No files found.
include/ck_tile/core.hpp
View file @
7971bb5b
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
...
...
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
7971bb5b
...
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate
,
xor_t
,
offset
,
indexing
,
};
template
<
index_t
NDimLow
,
index_t
NDimUp
>
...
...
@@ -1549,6 +1550,184 @@ struct offset : public base_transform<1, 1>
}
};
#if 0
template <typename UpLength,
typename Index>
struct indexing : public base_transform<1, 1>
{
static constexpr index_t NDimUp = 1;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using Indices = decltype(make_tuple(Index{}));
UpLengths up_lengths_;
Indices indices_;
CK_TILE_HOST_DEVICE constexpr indexing() = default;
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
const Index& index)
: up_lengths_{make_tuple(up_length)}, indices_{make_tuple(indices)}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::indexing;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /*idx_up*/) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = indices_[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /*idx_diff_up*/,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = 0;
//static_for<0, NDimUp, 1>{}(
// [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
// idx_low += idx_up;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Indices>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("indices_: ");
print(indices_);
printf("}");
}
};
#endif
template
<
typename
UpLength
,
typename
IndexingAdaptor
>
struct
indexing
:
public
base_transform
<
1
,
1
>
{
static
constexpr
index_t
NDimUp
=
1
;
using
LowerIndex
=
multi_index
<
1
>
;
using
UpperIndex
=
multi_index
<
1
>
;
using
UpLengths
=
decltype
(
make_tuple
(
UpLength
{}));
UpLengths
up_lengths_
;
IndexingAdaptor
iadaptor_
;
CK_TILE_HOST_DEVICE
constexpr
indexing
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing
(
const
UpLength
&
up_length
,
const
IndexingAdaptor
&
iadaptor
)
:
up_lengths_
{
make_tuple
(
up_length
)},
iadaptor_
{
iadaptor
}
{
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_type_enum
()
{
return
coord_transform_enum
::
indexing
;
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_upper_lengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
iadaptor_
.
calculate_lower_index
(
idx_low
,
idx_up
);
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
NDimUp
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
iadaptor_
.
update_lower_index
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_valid_upper_index_always_mapped_to_valid_lower_index
()
{
return
true
;
}
template
<
typename
UpIdx
>
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_valid_upper_index_mapped_to_valid_lower_index
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
UpLengths
>::
value
&&
IndexingAdaptor
::
is_known_at_compile_time
();
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"embed{"
);
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
"}"
);
}
};
//*******************************************************************************************************
template
<
typename
LowLength
>
...
...
@@ -1670,3 +1849,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
}
}
// namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace
ck_tile
{
template
<
typename
UpLength
,
typename
Indices
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_indexing_transform
(
const
UpLength
&
up_lengths
,
const
Indices
&
indices
)
{
// by default we use the simplest one
return
indexing
<
UpLength
,
indexing_adaptor_onshot_cached
<
remove_cvref_t
<
Indices
>>>
{
up_lengths
,
indexing_adaptor_onshot_cached
<
remove_cvref_t
<
Indices
>>
{
indices
}};
}
template
<
typename
UpLength
,
typename
IndexingAdaptor
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_indexing_transform_with_adaptor
(
const
UpLength
&
up_lengths
,
const
IndexingAdaptor
&
iadaptor
)
{
return
indexing
<
UpLength
,
IndexingAdaptor
>
{
up_lengths
,
iadaptor
};
}
}
// namespace ck_tile
include/ck_tile/core/algorithm/indexing_adaptor.hpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// pre-defined indexing adaptor used for indexing(scatter/gather)
// this version cache the index inside thread register(which is also prefered in real senario)
// however it's user's responsibility that each thread only provide one indexing, which means
// move coordinate will not change on this dim
template
<
typename
IndexingType
>
struct
indexing_adaptor_onshot_cached
{
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor_onshot_cached
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor_onshot_cached
(
const
IndexingType
&
idx
)
:
cached_idx_
(
idx
)
{
}
IndexingType
cached_idx_
;
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
/*idx_up*/
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
number
<
0
>
{})
=
cached_idx_
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
/*idx_low*/
,
const
UpIdx
&
/*idx_up*/
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
};
}
// namespace ck_tile
include/ck_tile/core/config.hpp
View file @
7971bb5b
...
...
@@ -31,12 +31,16 @@
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_EXTERN __host__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_EXTERN
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
...
...
@@ -191,3 +195,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
// workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
include/ck_tile/core/numeric/math.hpp
View file @
7971bb5b
...
...
@@ -41,9 +41,8 @@ struct scales
Scale
lhs_
;
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template
<
typename
Scale
>
__host__
__device__
scales
(
Scale
)
->
scales
<
Scale
>
;
CK_TILE_HOST_DEVICE_EXTERN
scales
(
Scale
)
->
scales
<
Scale
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
plus
...
...
@@ -66,8 +65,7 @@ struct plus<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
plus
()
->
plus
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
plus
()
->
plus
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
minus
...
...
@@ -90,8 +88,7 @@ struct minus<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
minus
()
->
minus
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
minus
()
->
minus
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
multiplies
...
...
@@ -114,8 +111,7 @@ struct multiplies<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
multiplies
()
->
multiplies
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
multiplies
()
->
multiplies
<
void
,
void
>
;
template
<
typename
T
>
struct
maximize
...
...
@@ -345,8 +341,7 @@ struct equal<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
equal
()
->
equal
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
equal
()
->
equal
<
void
,
void
>
;
template
<
>
struct
equal
<
float
,
float
>
...
...
@@ -387,8 +382,7 @@ struct less<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
less
()
->
less
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
less
()
->
less
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
less_equal
...
...
@@ -411,8 +405,7 @@ struct less_equal<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__
__device__
less_equal
()
->
less_equal
<
void
,
void
>
;
CK_TILE_HOST_DEVICE_EXTERN
less_equal
()
->
less_equal
<
void
,
void
>
;
template
<
>
struct
less_equal
<
float
,
float
>
...
...
@@ -488,19 +481,19 @@ template <typename T = double>
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
// math
CK_TILE_HOST_DEVICE
float
abs
(
const
float
&
x
)
{
union
{
float
f32
;
uint32_t
u32
;
}
y
;
y
.
f32
=
x
;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
return
y
.
f32
;
}
//
CK_TILE_HOST_DEVICE
//
float abs(const float& x)
//
{
//
union
//
{
//
float f32;
//
uint32_t u32;
//
} y;
//
y.f32 = x;
//
y.u32 = y.u32 & 0x7fffffff;
//
return y.f32;
//
}
#if 0
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
...
...
@@ -523,18 +516,20 @@ float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
#endif
CK_TILE_DEVICE
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
CK_TILE_HOST
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
#if 0
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
#endif
CK_TILE_DEVICE
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
...
...
@@ -547,4 +542,932 @@ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
///////////////////////////////////////////////////////////////
}
// namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace
ck_tile
{
#if CK_TILE_WORKAROUND_SWDEV_383542
extern
"C"
CK_TILE_DEVICE
float
__ocml_native_recip_f32
(
float
);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
}
#endif
CK_TILE_HOST
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_HOST
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
std
::
sqrt
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acos
<
float
>
(
float
x
)
{
return
std
::
acosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acos
<
double
>
(
double
x
)
{
return
std
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
typename
T
>
CK_TILE_HOST
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atan
<
float
>
(
float
x
)
{
return
std
::
atanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atan
<
double
>
(
double
x
)
{
return
std
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sin
<
float
>
(
float
x
)
{
return
std
::
sinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sin
<
double
>
(
double
x
)
{
return
std
::
sin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asin
<
float
>
(
float
x
)
{
return
std
::
asinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asin
<
double
>
(
double
x
)
{
return
std
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asinh
<
float
>
(
float
x
)
{
return
std
::
asinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asinh
<
double
>
(
double
x
)
{
return
std
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
cosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cos
<
float
>
(
float
x
)
{
return
std
::
cosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cos
<
double
>
(
double
x
)
{
return
std
::
cos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acosh
<
float
>
(
float
x
)
{
return
std
::
acoshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acosh
<
double
>
(
double
x
)
{
return
std
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tan
<
float
>
(
float
x
)
{
return
std
::
tanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tan
<
double
>
(
double
x
)
{
return
std
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atanh
<
float
>
(
float
x
)
{
return
std
::
atanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atanh
<
double
>
(
double
x
)
{
return
std
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sinh
<
float
>
(
float
x
)
{
return
std
::
sinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sinh
<
double
>
(
double
x
)
{
return
std
::
sinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
ceil
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
ceilf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
ceil
<
float
>
(
float
x
)
{
return
std
::
ceilf
(
x
);
};
template
<
>
CK_TILE_HOST
double
ceil
<
double
>
(
double
x
)
{
return
std
::
ceil
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
coshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cosh
<
float
>
(
float
x
)
{
return
std
::
coshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cosh
<
double
>
(
double
x
)
{
return
std
::
cosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
floor
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
floorf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
floor
<
float
>
(
float
x
)
{
return
std
::
floorf
(
x
);
};
template
<
>
CK_TILE_HOST
double
floor
<
double
>
(
double
x
)
{
return
std
::
floor
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
rcp
(
T
x
)
{
return
type_convert
<
T
>
(
1.
f
/
type_convert
<
float
>
(
x
));
};
template
<
typename
T
>
CK_TILE_HOST
T
exp
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
exp
<
float
>
(
float
x
)
{
return
std
::
expf
(
x
);
}
template
<
>
CK_TILE_HOST
double
exp
<
double
>
(
double
x
)
{
return
std
::
exp
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
log
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
logf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
log
<
float
>
(
float
x
)
{
return
std
::
logf
(
x
);
}
template
<
>
CK_TILE_HOST
double
log
<
double
>
(
double
x
)
{
return
std
::
log
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
pow
(
T
x
,
T
gamma
)
{
return
type_convert
<
T
>
(
std
::
powf
(
type_convert
<
float
>
(
x
),
type_convert
<
float
>
(
gamma
)));
}
template
<
>
CK_TILE_HOST
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
std
::
powf
(
x
,
gamma
);
}
template
<
>
CK_TILE_HOST
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
std
::
pow
(
x
,
gamma
);
}
template
<
typename
T
>
CK_TILE_HOST
T
expm1
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expm1f
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
expm1
<
float
>
(
float
x
)
{
return
std
::
expm1f
(
x
);
}
template
<
>
CK_TILE_HOST
double
expm1
<
double
>
(
double
x
)
{
return
std
::
expm1
(
x
);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
CK_TILE_DEVICE
float
abs
(
float
x
)
{
union
{
float
f32
;
uint32_t
u32
;
}
y
;
y
.
f32
=
x
;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
return
y
.
f32
;
};
CK_TILE_DEVICE
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
CK_TILE_DEVICE
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_DEVICE
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#endif
CK_TILE_DEVICE
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
CK_TILE_DEVICE
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_DEVICE
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_DEVICE
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
CK_TILE_DEVICE
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tanh
<
float
>
(
float
x
)
{
return
::
tanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tanh
<
double
>
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acos
<
float
>
(
float
x
)
{
return
::
acosf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acos
<
double
>
(
double
x
)
{
return
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
fp16_t
neg
<
fp16_t
>
(
fp16_t
x
)
{
return
__hneg
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atan
<
float
>
(
float
x
)
{
return
::
atanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atan
<
double
>
(
double
x
)
{
return
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sin
<
float
>
(
float
x
)
{
return
::
sinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sin
<
double
>
(
double
x
)
{
return
::
sin
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
sin
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hsin
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asin
<
float
>
(
float
x
)
{
return
::
asinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asin
<
double
>
(
double
x
)
{
return
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asinh
<
float
>
(
float
x
)
{
return
::
asinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asinh
<
double
>
(
double
x
)
{
return
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acosh
<
float
>
(
float
x
)
{
return
::
acoshf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acosh
<
double
>
(
double
x
)
{
return
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tan
<
float
>
(
float
x
)
{
return
::
tanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tan
<
double
>
(
double
x
)
{
return
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atanh
<
float
>
(
float
x
)
{
return
::
atanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atanh
<
double
>
(
double
x
)
{
return
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sinh
<
float
>
(
float
x
)
{
return
::
sinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sinh
<
double
>
(
double
x
)
{
return
::
sinh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
ceil
(
T
x
)
{
return
type_convert
<
T
>
(
::
ceilf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
ceil
<
float
>
(
float
x
)
{
return
::
ceilf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
ceil
<
double
>
(
double
x
)
{
return
::
ceil
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
ceil
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hceil
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
cosh
(
T
x
)
{
return
type_convert
<
T
>
(
::
coshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
cosh
<
float
>
(
float
x
)
{
return
::
coshf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
cosh
<
double
>
(
double
x
)
{
return
::
cosh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
floor
(
T
x
)
{
return
type_convert
<
T
>
(
::
floorf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
floor
<
float
>
(
float
x
)
{
return
::
floorf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
floor
<
double
>
(
double
x
)
{
return
::
floor
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
floor
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hfloor
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
rcp
(
T
x
)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
return
__frcp_rn
(
x
);
#else
return
__ocml_native_recip_f32
(
x
);
#endif
};
template
<
typename
T
>
CK_TILE_DEVICE
T
exp
(
T
x
)
{
return
type_convert
<
T
>
(
__ocml_exp_f32
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
fp16_t
exp
<
fp16_t
>
(
fp16_t
x
)
{
return
hexp
(
x
);
};
template
<
>
CK_TILE_DEVICE
float
exp
<
float
>
(
float
x
)
{
return
__ocml_exp_f32
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
log
(
T
x
)
{
return
type_convert
<
T
>
(
__logf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
fp16_t
log
<
fp16_t
>
(
fp16_t
x
)
{
return
hlog
(
x
);
};
template
<
>
CK_TILE_DEVICE
float
log
<
float
>
(
float
x
)
{
return
__logf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
log
<
double
>
(
double
x
)
{
return
log
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
pow
(
T
x
,
T
gamma
)
{
return
type_convert
<
T
>
(
powf
(
type_convert
<
float
>
(
x
),
type_convert
<
float
>
(
gamma
)));
};
template
<
>
CK_TILE_DEVICE
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
powf
(
x
,
gamma
);
};
template
<
>
CK_TILE_DEVICE
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
pow
(
x
,
gamma
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
expm1
(
T
x
)
{
return
type_convert
<
T
>
(
expm1f
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
expm1
<
float
>
(
float
x
)
{
return
expm1f
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
expm1
<
double
>
(
double
x
)
{
return
expm1
(
x
);
};
}
// namespace ck_tile
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
7971bb5b
...
...
@@ -17,6 +17,14 @@
namespace
ck_tile
{
namespace
detail
{
template
<
typename
Distribution
>
CK_TILE_HOST_DEVICE
auto
get_partition_index
(
Distribution
)
{
return
Distribution
::
_get_partition_index
();
}
}
// namespace detail
// distributed span
template
<
index_t
...
PartialHsLengths
>
struct
tile_distributed_span
...
...
@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
auto
_get_partition_index
()
{
// only support warp-tile and block-tile
static_assert
(
NDimP
==
1
or
NDimP
==
2
,
"wrong!"
);
if
constexpr
(
NDimP
==
1
)
{
return
array
<
index_t
,
1
>
{
get_lane_id
()};
}
else
if
constexpr
(
NDimP
==
2
)
{
return
array
<
index_t
,
2
>
{
get_warp_id
(),
get_lane_id
()};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
{
#if 0
...
...
@@ -149,6 +172,16 @@ struct tile_distribution
}
#endif
// Compute the bottom (X) index corresponding to the given partition index,
// defaulting to the current thread's own partition. The Y components are
// fixed to 0, so this yields the origin of that partition's sub-tile.
template <typename PartitionIndex = decltype(_get_partition_index())>
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
{
    // concat P index with zeroed Y index to form the full top (P,Y) index
    const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});

    const auto window_adaptor_thread_coord_tmp =
        make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);

    return window_adaptor_thread_coord_tmp.get_bottom_index();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
{
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
...
...
@@ -500,22 +533,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
namespace
detail
{
// Resolve the executing HW thread's partition index for the given
// distribution type (argument used for deduction only): warp-tile
// (NDimP == 1) -> {lane}, block-tile (NDimP == 2) -> {warp, lane}.
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");

    if constexpr(Distribution::NDimP == 1)
    {
        return array<index_t, 1>{get_lane_id()};
    }
    else if constexpr(Distribution::NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
}
// Forward declaration; implementation detail of the reverse-slice-sequence
// machinery (definition not visible in this hunk).
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
7971bb5b
...
...
@@ -41,6 +41,7 @@ struct tile_window_with_static_distribution
// Frequently used compile-time index constants.
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};

// Current implementation supports exactly one pre-computed coordinate.
static_assert(NumCoord == 1);

// TODO: check WindowLengths and StaticTileDistribution are consistent
...
...
include/ck_tile/ops/elementwise.hpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
element_wise
{
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
// Identity / type-cast elementwise functor: y = x, converting through
// type_convert whenever Y and X differ. Only the explicitly specialized
// (Y, X) pairs below are supported — the primary template is declared but
// left undefined so an unsupported pair fails at build time.
// NOTE(review): in-class explicit specialization of a member function
// template is a compiler extension (accepted by hip-clang), not standard C++.
struct PassThrough
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
    {
        y = type_convert<double>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(
        ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
                                                                const float& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(
        ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
                                                                const float& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
                                                                const ck_tile::bf16_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::fp16_t>(
        ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
                                                                const ck_tile::fp16_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
                                                                 const int8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
                                                                 const int8_t& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
    {
        y = type_convert<int8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
    {
        y = type_convert<int32_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
    {
        y = type_convert<int8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
    {
        y = type_convert<float>(x);
    }

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
    {
        y = type_convert<int4_t>(x);
    }
#endif

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, ck_tile::fp8_t>(
        ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
                                                               const ck_tile::fp8_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
                                                               const float& x) const
    {
        y = type_convert<ck_tile::fp8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp8_t>(
        ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, ck_tile::fp16_t>(
        ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
    {
        y = type_convert<ck_tile::fp8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, ck_tile::bf8_t>(
        ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const { y = x; }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
                                                               const ck_tile::bf8_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
                                                               const float& x) const
    {
        y = type_convert<ck_tile::bf8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::bf8_t>(
        ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, ck_tile::fp16_t>(
        ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::type_convert<ck_tile::bf8_t>(x);
    }
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::bf16_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
// Elementwise scaling: y = scale * x. The generic overload computes in float
// and converts to Y; specializations below tune precision per type (note the
// fp16 overload multiplies in fp16, i.e. lower precision than the generic path).
struct Scale
{
    CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}

    // generic: up-convert x to float, scale, convert to Y
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
    }

    // fp16: scale converted to fp16 and multiplied in fp16
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(
        ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
    };

    // bf16: computed in float for accuracy, then rounded back to bf16
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(
        ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        const float x_tmp = ck_tile::type_convert<float>(x);
        const float y_tmp = scale_ * x_tmp;
        y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
    };

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = scale_ * x;
    };

    // double: note scale_ is stored as float, so the factor is float-precision
    template <>
    CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
    {
        y = scale_ * x;
    };

    // int8: scaled in float, then converted (truncated/saturated by type_convert)
    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
        y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
    };

    float scale_;
};
// y = scale * x, but NaN inputs are mapped to -infinity (e.g. to make them
// sort last / be ignored by subsequent max-reductions). Only float -> float
// is specialized; the primary template is declared but undefined.
struct ScaleAndResetNaNToMinusInfinity
{
    CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = ck_tile::isnan(x) ? -ck_tile::NumericLimits<float>::Infinity() : scale_ * x;
    };

    float scale_;
};
// Elementwise division by a fixed integer divider: y = x / divider.
struct UnaryDivide
{
    CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");

        // convert the divider once into the operand's type, then divide
        const T denom = type_convert<T>(divider_);
        y = x / denom;
    };

    int32_t divider_ = 1;
};
// Elementwise square: y = x * x.
struct UnarySquare
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same<T, float>::value || is_same<T, ck_tile::fp16_t>::value ||
                          is_same<T, double>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                          || is_same<T, int4_t>::value
#endif
                      ,
                      "Data type is not supported by this operation!");
        y = x * x;
    };
};
// Elementwise absolute value: y = |x|.
struct UnaryAbs
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int32_t> ||
                          is_same_v<T, int8_t>,
                      "Data type is not supported by this operation!");

        y = ck_tile::abs(x);
    };
};
// Elementwise square root: y = sqrt(x). Floating-point only.
struct UnarySqrt
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double>,
                      "Data type is not supported by this operation!");

        y = ck_tile::sqrt(x);
    };
};
// Rectified linear unit: y = max(x, 0). bf16 gets its own overload that
// computes in float, since bf16 has no native comparison/zero literal path here.
struct Relu
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        y = x > 0 ? x : 0;
    }

    // bf16: round-trip through float
    template <>
    CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float x_f32 = ck_tile::type_convert<float>(x);
        float y_f32 = x_f32 > 0 ? x_f32 : 0;
        y           = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
    }
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
//
// Implemented in the equivalent sigmoid form y = x / (1 + exp(u)) with
// u = -2*x*(0.035677*x^2 + 0.797885). Host and device overloads are kept
// separate so the device path can use the faster __ocml_exp_f32 and rcp;
// all fp16/bf16 overloads compute in float and convert the result back.
struct FastGelu
{
    template <typename Y, typename X>
    CK_TILE_HOST void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;

    // host code: full-precision exp and real division
    template <>
    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1 = -2.0 * 0.035677f;
        const float c2 = -2.0 * 0.797885f;
        const float u  = x * (c1 * x * x + c2);

        const float emu = exp(u);
        y               = x / (1.f + emu);
    }

    // device code, use lower precision "__ocml_exp_f32" and "rcp"
    template <>
    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1 = -2.0 * 0.035677f;
        const float c2 = -2.0 * 0.797885f;
        const float u  = x * (c1 * x * x + c2);

        const float emu = __ocml_exp_f32(u);
        y               = x * ck_tile::rcp(1.f + emu);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(
        ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(
        ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(
        ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(
        ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
// Exact (erf-based) GeLU; 0.70710678118 is 1/sqrt(2). The fp16 overload
// evaluates erf in float and performs the remaining arithmetic in fp16.
struct Gelu
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(
        ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::fp16_t(0.5) * x *
            (ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
    }
};
// Logistic sigmoid: y = 1 / (1 + e^-x).
struct Sigmoid
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");

        constexpr T one = type_convert<T>(1);
        const T denom   = one + ck_tile::exp(-x);
        y               = one / denom;
    };
};
// SiLU (a.k.a. swish-1): y = x * sigmoid(x) = x / (1 + e^-x).
struct Silu
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int8_t>::value ||
                          ck_tile::is_same<T, int32_t>::value,
                      "Data type is not supported by this operation!");

        constexpr T one = type_convert<T>(1);
        const T sig     = one / (one + ck_tile::exp(-x));
        y               = x * sig;
    };
};
// --- One-call unary math functors ---------------------------------------
// Each functor below forwards to the matching ck_tile device/host math
// routine after restricting T to the supported scalar types
// (float, double, fp16, int8, int32).

// y = tanh(x)
struct TanH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::tanh(x);
    };
};

// y = acos(x)
struct ACos
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::acos(x);
    };
};

// y = -x
struct Neg
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::neg(x);
    };
};

// y = atan(x)
struct ATan
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::atan(x);
    };
};

// y = sin(x)
struct Sin
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::sin(x);
    };
};

// y = asinh(x)
struct ASinH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::asinh(x);
    };
};

// y = cos(x)
struct Cos
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::cos(x);
    };
};

// y = acosh(x)
struct ACosH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::acosh(x);
    };
};

// y = tan(x)
struct Tan
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::tan(x);
    };
};

// y = atanh(x)
struct ATanH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::atanh(x);
    };
};

// y = sinh(x)
struct SinH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::sinh(x);
    };
};

// y = ceil(x)
struct Ceil
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::ceil(x);
    };
};

// y = exp(x)
struct Exp
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::exp(x);
    };
};

// y = cosh(x)
struct CosH
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::cosh(x);
    };
};

// y = floor(x)
struct Floor
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::floor(x);
    };
};

// y = log(x) (natural logarithm)
struct Log
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::log(x);
    };
};

// y = asin(x)
struct ASin
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::asin(x);
    };
};

// y = 1 / x (reciprocal)
struct Rcp
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
                          is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::rcp(x);
    };
};
// Generalized swish: y = x / (1 + e^(-beta * x)); beta defaults to 1 (SiLU).
// The exponent is computed in float regardless of X/Y.
struct Swish
{
    Swish(float beta = 1.0f) : beta_(beta) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, double>::value ||
                          ck_tile::is_same<X, ck_tile::fp16_t>::value,
                      "Data type is not supported by this operation!");
        static_assert(ck_tile::is_same<Y, float>::value || ck_tile::is_same<Y, double>::value ||
                          ck_tile::is_same<Y, ck_tile::fp16_t>::value,
                      "Data type is not supported by this operation!");

        float bx = -beta_ * type_convert<float>(x);
        y        = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
    };

    const float beta_;
};
// Softplus with sharpness alpha: y = log(1 + e^(alpha*x)) / alpha.
struct SoftRelu
{
    SoftRelu(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
    }
    const float alpha_;
};
// Shifted/scaled power: y = (alpha + beta * x) ^ gamma.
// Defaults give y = x^2.
struct Power
{
    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
        : alpha_(alpha), beta_(beta), gamma_(gamma){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha     = type_convert<T>(alpha_);
        T casted_beta      = type_convert<T>(beta_);
        T casted_gamma     = type_convert<T>(gamma_);
        T shifted_scaled_x = casted_alpha + casted_beta * x;
        y                  = ck_tile::pow(shifted_scaled_x, casted_gamma);
    }
    const float alpha_;
    const float beta_;
    const float gamma_;
};
// Clipped ReLU / hard clamp: y = min(beta, max(alpha, x)),
// i.e. x clamped to [alpha, beta] (defaults [0, 1]).
struct ClippedRelu
{
    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        T casted_beta  = type_convert<T>(beta_);
        y              = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
    }
    const float alpha_;
    const float beta_;
};
// Leaky ReLU: y = x for x >= 0, alpha * x otherwise (alpha defaults to 0.01).
struct LeakyRelu
{
    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(is_same_v<T, float> || is_same_v<T, double> ||
                          is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int32_t> ||
                          is_same_v<T, int8_t>,
                      "Data type is not supported by this operation!");

        const T slope = type_convert<T>(alpha_);
        if(x >= 0)
        {
            y = x;
        }
        else
        {
            y = x * slope;
        }
    }
    const float alpha_;
};
// Exponential linear unit: y = x for x > 0, alpha * (e^x - 1) otherwise.
struct Elu
{
    Elu(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha = type_convert<T>(alpha_);
        // expm1 is used for accuracy near x == 0
        y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
    }
    const float alpha_;
};
// Scaled logistic: y = alpha / (1 + alpha * e^(-x)).
// For alpha == 1 this is the standard sigmoid. NOTE(review): alpha multiplies
// the exponential term in the denominator (not just the numerator) — unusual
// form, preserved as written; confirm against the intended definition.
struct Logistic
{
    Logistic(float alpha = 1.f) : alpha_(alpha){};

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
                          ck_tile::is_same<T, ck_tile::fp16_t>::value ||
                          ck_tile::is_same<T, int32_t>::value ||
                          ck_tile::is_same<T, int8_t>::value,
                      "Data type is not supported by this operation!");
        T casted_alpha  = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y               = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
    }
    const float alpha_;
};
// Convolution epilogue: divide the accumulator by the input/weight/output
// scales and quantize to fp8: e = c / scale_in / scale_wei / scale_out.
// Only float -> fp8 is specialized; the primary template is undefined.
struct ConvInvscale
{
    CK_TILE_HOST_DEVICE ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
    };

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
// Convolution epilogue: multiply the accumulator by the input/weight/output
// scales and quantize to fp8: e = c * scale_in * scale_wei * scale_out.
// Only float -> fp8 is specialized; the primary template is undefined.
struct ConvScale
{
    CK_TILE_HOST_DEVICE ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
    };

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
// Scales the accumulator, applies ReLU, then applies the output scale before
// converting to fp8: e = fp8(relu(c * scale_in * scale_wei) * scale_out).
struct ConvScaleRelu
{
    CK_TILE_HOST_DEVICE ConvScaleRelu(float scale_in  = 1.f,
                                      float scale_wei = 1.f,
                                      float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    // Primary template is declared but not defined: only the (fp8_t, float)
    // specialization below is callable; other instantiations fail at link time.
    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    // NOTE(review): explicit specialization of a member template inside the
    // class body relies on compiler support (accepted by clang/hipcc) —
    // confirm if this header must build on other toolchains.
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        float x;
        // Clamp the input-and-weight-scaled value through the Relu functor
        // (defined earlier in this header); output scale is applied afterwards.
        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
        e = type_convert<ck_tile::fp8_t>(x * scale_out_);
    };

    float scale_in_;  // pre-ReLU scale factor (input)
    float scale_wei_; // pre-ReLU scale factor (weights)
    float scale_out_; // post-ReLU scale factor (output)
};
// support fastconvert of int8 to fp16
// Primary template: intentionally empty. It carries no convert()/operator(),
// so only the specializations below (uint8_t -> fp16_t, pack size 4 or a
// multiple of 4) provide a usable converter; any other combination has no
// callable members.
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
// Specialization: converts a packed group of 4 uint8 values into 4 fp16 values
// using two byte-permutes and two packed fp16 adds, instead of four scalar
// integer->float conversions.
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
    using InputArray  = vector_type<uint8_t, 4>;
    using OutputArray = vector_type<ck_tile::fp16_t, 4>;

    CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
    {
        OutputArray Output;

        // View the 4 output halves as two 32-bit words (2 fp16 lanes each).
        uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);

        // View the 4 input bytes as a single 32-bit word.
        uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);

        // v_perm selectors: each output fp16 lane gets one input byte in its
        // low byte and 0x64 (from fp16_adder) in its high byte, producing the
        // bit patterns 0x64XX (fp16 values of the form 1024 + byte).
        static constexpr uint32_t byte_selector_01 = 0x05010500;
        static constexpr uint32_t byte_selector_23 = 0x05030502;
        static constexpr uint32_t fp16_adder       = 0x64646464;

        half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
        half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

        // Packed subtract of the magic constant (0x6480 per lane) removes the
        // 0x64__ bias; neg_lo/neg_hi negate the second operand of v_pk_add_f16.
        // NOTE(review): 0x6480 is fp16 1152 = 1024 + 128, so lanes appear to
        // come out as (byte - 128), i.e. a signed-i8-stored-as-biased-u8
        // convention — confirm against the callers' expectations.
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[0])
                     : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[1])
                     : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));

        return Output;
    }

    CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
// Partial specialization for any pack size N (multiple of 4): widens the
// 4-element converter above by converting one 4-wide chunk at a time.
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using InputArray  = vector_type<uint8_t, N>;
    using OutputArray = vector_type<ck_tile::fp16_t, N>;

    CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
    {
        using ChunkIn  = vector_type<uint8_t, VEC_WIDTH>;
        using ChunkOut = vector_type<ck_tile::fp16_t, VEC_WIDTH>;

        FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, VEC_WIDTH> chunk_converter;
        OutputArray Output;

        // Reinterpret the N-wide arrays as sequences of 4-wide chunks.
        ChunkOut* out_chunks      = reinterpret_cast<ChunkOut*>(&Output);
        ChunkIn const* in_chunks  = reinterpret_cast<ChunkIn const*>(&Input);

        static_for<0, N / VEC_WIDTH, 1>{}(
            [&](auto chunk) { out_chunks[chunk] = chunk_converter(in_chunks[chunk]); });

        return Output;
    }

    CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
}
// namespace element_wise
}
// namespace ck_tile
test/CMakeLists.txt
View file @
7971bb5b
...
...
@@ -217,3 +217,5 @@ if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_
add_subdirectory
(
smfmac_op
)
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
scatter_gather
)
test/scatter_gather/CMakeLists.txt
0 → 100644
View file @
7971bb5b
add_test_executable
(
test_scatter_gather scatter_gather.cpp
)
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
test/scatter_gather/scatter_gather.cpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#ifndef TEST_SCATTER_GATHER_VERBOSE
#define TEST_SCATTER_GATHER_VERBOSE 0
#endif
// Abort the test on any failed HIP runtime call, reporting the failing
// expression and error code. Fixes the original, which called exit(0) on
// error (reporting success to the test harness) and omitted the newline.
#define HIP_CALL(call)                                                               \
    do                                                                               \
    {                                                                                \
        hipError_t err = call;                                                       \
        if(err != hipSuccess)                                                        \
        {                                                                            \
            printf("[hiperror](%d) fail to call %s\n", static_cast<int>(err), #call); \
            exit(EXIT_FAILURE);                                                      \
        }                                                                            \
    } while(0)
// Per-block row gather + row scatter:
//   dst[dst_row_idx_ptr[i], :] = src[src_row_idx_ptr[i], :]
// Each thread-block handles ROW_TILE_SIZE selected rows, sweeping the columns
// in COL_TILE_SIZE-wide tiles. The gather/scatter is expressed through
// ck_tile's make_indexing_transform, so load_tile/store_tile are used exactly
// as for ordinary (non-indexed) memory.
template <ck_tile::index_t ROW_TILE_SIZE = 8,
          ck_tile::index_t COL_TILE_SIZE = 32 * 8,
          ck_tile::index_t BLOCK_SIZE    = 256,
          ck_tile::index_t ALIGNMENT     = 8,
          typename INDEX_BUF_TYPE        = ck_tile::index_t,
          typename DATA_TYPE             = ck_tile::fp16_t>
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
                                   const INDEX_BUF_TYPE* dst_row_idx_ptr,
                                   const DATA_TYPE* src_ptr,
                                   DATA_TYPE* dst_ptr,
                                   ck_tile::index_t n_row_total,
                                   ck_tile::index_t /*n_row_select*/,
                                   ck_tile::index_t n_cols)
{
    using namespace ck_tile;

    // some constexpr vars
    // Derive the in-block thread layout: `vec` contiguous elements per thread
    // along the column axis, `col_lanes` threads covering one column tile,
    // the remaining lanes and warps stacked along the row axis.
    constexpr index_t vec = ALIGNMENT;
    static_assert(COL_TILE_SIZE % vec == 0);
    constexpr index_t col_lanes = COL_TILE_SIZE / vec;
    constexpr index_t warp_size = ck_tile::get_warp_size();
    static_assert(warp_size % col_lanes == 0);
    constexpr index_t row_lanes = warp_size / col_lanes;
    constexpr index_t num_warps = BLOCK_SIZE / warp_size;
    static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
    constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
    // Each thread may own at most one row; with row_repeat > 1 a thread would
    // need several row indices, which this indexing scheme does not cover.
    static_assert(row_repeat == 1,
                  "currently indexing not support(and would be not performant) if row_repeat has more");

    // tile partitioner
    index_t tile_col_idx = 0;
    index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;

    // create our tild distribution, which tell us the location of different threads
    constexpr auto src_dist = make_static_tile_distribution(
        tile_distribution_encoding<
            sequence<1>,
            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
            tuple<sequence<1>, sequence<1, 2>>,
            tuple<sequence<1>, sequence<2, 0>>,
            sequence<1, 2>,
            sequence<0, 1>>{});

    // Logical (row, col) of this thread's first element within the tile;
    // only the row component is needed to pick the per-thread row index.
    const auto coord     = src_dist.calculate_index();
    const auto row_coord = coord[number<0>{}] + tile_row_idx;

    // load the current row index from the indexing buffer. we do not use ck_tile utility here
    INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
    INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
    // printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
    // static_cast<int>(src_row_id), static_cast<int>(dst_row_id));

    // Plain row-major [n_row_total, n_cols] view of the source buffer.
    const auto src_view =
        make_naive_tensor_view<address_space_enum::global>(src_ptr,
                                                           make_tuple(n_row_total, n_cols),
                                                           make_tuple(n_cols, 1),
                                                           number<vec>{}, // alignement
                                                           number<1>{});

    // Wrap the row axis in an indexing transform: accesses to logical row r
    // are redirected to physical row src_row_id (the gather).
    const auto src_gather_view = transform_tensor_view(
        src_view,
        make_tuple(make_indexing_transform(n_row_total, src_row_id),
                   // here we replace row_idx which is loaded from another buffer
                   make_pass_through_transform(n_cols)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    auto src_tile = make_tile_window(src_gather_view,
                                     make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
                                     {tile_row_idx, tile_col_idx},
                                     src_dist);

    // Same construction on the destination side, with dst_row_id (the scatter).
    const auto dst_view =
        make_naive_tensor_view<address_space_enum::global>(dst_ptr,
                                                           make_tuple(n_row_total, n_cols),
                                                           make_tuple(n_cols, 1),
                                                           number<vec>{},
                                                           number<1>{});

    const auto dst_scatter_view = transform_tensor_view(
        dst_view,
        make_tuple(make_indexing_transform(n_row_total, dst_row_id),
                   // here we replace row_idx which is loaded from another buffer
                   make_pass_through_transform(n_cols)),
        make_tuple(sequence<0>{}, sequence<1>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    auto dst_tile = make_tile_window(dst_scatter_view,
                                     make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
                                     {tile_row_idx, tile_col_idx},
                                     src_dist /*reuse distribution*/);

    // we finished descriptor construction and index calculation, now start load/store
    // Sweep all columns, one COL_TILE_SIZE-wide tile at a time.
    for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
    {
        // note that scatter/gather are just the same API when doing load store as normal memory
        // operation
        auto data = load_tile(src_tile);
        store_tile(dst_tile, data);
        move_tile_window(src_tile, {0, COL_TILE_SIZE});
        move_tile_window(dst_tile, {0, COL_TILE_SIZE});
    }
}
// 16-bit test pattern: packs a (row, column) pair into one fp16-sized value so
// every element can be traced back to its source location after the
// scatter/gather round trip. r:6 bits fits row_total = 64; c:10 bits fits the
// column counts used here.
// NOTE: bit-field layout and reading the inactive union member are
// implementation-defined / GNU extensions — fine for this GCC/Clang-built test.
union pixel
{
    struct __attribute__((packed))
    {
        unsigned int r : 6;
        unsigned int c : 10;
    };
    // `unsigned short` replaces the non-standard POSIX `ushort` typedef: the
    // identical 16-bit type, without depending on <sys/types.h> leakage.
    unsigned short data;
};
// Draws random integers in [0, capacity) without repetition, remembering past
// draws in `set`. Once every value has been produced it warns and falls back
// to an unconstrained draw (which may repeat).
struct unique_linear_rand
{
    unique_linear_rand(int capacity_) : capacity(capacity_) {}

    std::unordered_set<int> set;

    // Returns a value not produced before, unless the pool is exhausted.
    int gen()
    {
        const bool exhausted = static_cast<int>(set.size()) >= capacity;
        if(exhausted)
        {
            printf("overflow, but will give you an number as well\n");
            return std::rand() % capacity;
        }
        for(;;)
        {
            const int candidate = std::rand() % capacity;
            // insert().second is true only for a value we have not seen yet.
            if(set.insert(candidate).second)
            {
                return candidate;
            }
        }
    }

    int capacity;
};
int
main
()
{
int
row_total
=
64
;
int
row_select
=
8
*
2
;
int
col
=
256
*
2
;
using
fp16_t
=
ck_tile
::
fp16_t
;
constexpr
int
row_tile
=
8
;
constexpr
int
col_tile
=
256
;
fp16_t
*
src
=
reinterpret_cast
<
fp16_t
*>
(
malloc
(
row_total
*
col
*
sizeof
(
fp16_t
)));
for
(
int
i_r
=
0
;
i_r
<
row_total
;
i_r
++
)
{
for
(
int
i_c
=
0
;
i_c
<
col
;
i_c
++
)
{
int
i
=
i_r
*
col
+
i_c
;
pixel
p
;
p
.
r
=
i_r
;
p
.
c
=
i_c
;
ushort
d
=
p
.
data
;
src
[
i
]
=
ck_tile
::
bit_cast
<
fp16_t
>
(
d
);
// for simplicity, just cast
}
}
fp16_t
*
dst
=
reinterpret_cast
<
fp16_t
*>
(
malloc
(
row_total
*
col
*
sizeof
(
fp16_t
)));
int
*
src_idx
=
reinterpret_cast
<
int
*>
(
malloc
(
row_select
*
sizeof
(
int
)));
int
*
dst_idx
=
reinterpret_cast
<
int
*>
(
malloc
(
row_select
*
sizeof
(
int
)));
// std::srand(std::time(std::nullptr));
// std::srand(11935);
std
::
srand
(
std
::
time
(
nullptr
));
auto
src_gen
=
unique_linear_rand
(
row_total
);
auto
dst_gen
=
unique_linear_rand
(
row_total
);
// dst index must be unique. src is fine
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
src_idx
[
i_r
]
=
src_gen
.
gen
();
dst_idx
[
i_r
]
=
dst_gen
.
gen
();
}
void
*
dev_src
;
void
*
dev_dst
;
void
*
dev_src_idx
;
void
*
dev_dst_idx
;
HIP_CALL
(
hipMalloc
(
&
dev_src
,
row_total
*
col
*
sizeof
(
fp16_t
)));
HIP_CALL
(
hipMalloc
(
&
dev_dst
,
row_total
*
col
*
sizeof
(
fp16_t
)));
HIP_CALL
(
hipMalloc
(
&
dev_src_idx
,
row_select
*
sizeof
(
int
)));
HIP_CALL
(
hipMalloc
(
&
dev_dst_idx
,
row_select
*
sizeof
(
int
)));
HIP_CALL
(
hipMemcpy
(
dev_src
,
src
,
row_total
*
col
*
sizeof
(
fp16_t
),
hipMemcpyHostToDevice
));
HIP_CALL
(
hipMemcpy
(
dev_src_idx
,
src_idx
,
row_select
*
sizeof
(
int
),
hipMemcpyHostToDevice
));
HIP_CALL
(
hipMemcpy
(
dev_dst_idx
,
dst_idx
,
row_select
*
sizeof
(
int
),
hipMemcpyHostToDevice
));
constexpr
int
bdim
=
256
;
int
gdim
=
(
row_select
+
row_tile
-
1
)
/
row_tile
;
row_scatter_gather
<
row_tile
,
col_tile
><<<
gdim
,
bdim
>>>
(
reinterpret_cast
<
int
*>
(
dev_src_idx
),
reinterpret_cast
<
int
*>
(
dev_dst_idx
),
reinterpret_cast
<
fp16_t
*>
(
dev_src
),
reinterpret_cast
<
fp16_t
*>
(
dev_dst
),
row_total
,
row_select
,
col
);
HIP_CALL
(
hipMemcpy
(
dst
,
dev_dst
,
row_total
*
col
*
sizeof
(
fp16_t
),
hipMemcpyDeviceToHost
));
#if TEST_SCATTER_GATHER_VERBOSE
printf
(
"select row:"
);
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
printf
(
"%d->%d->%d "
,
i_r
,
src_idx
[
i_r
],
dst_idx
[
i_r
]);
}
printf
(
"
\n
"
);
#endif
int
err_cnt
=
0
;
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
for
(
int
i_c
=
0
;
i_c
<
col
;
i_c
++
)
{
int
i
=
dst_idx
[
i_r
]
*
col
+
i_c
;
pixel
p
=
ck_tile
::
bit_cast
<
pixel
>
(
dst
[
i
]);
bool
is_ok
=
p
.
r
==
src_idx
[
i_r
]
&&
p
.
c
==
i_c
;
if
(
!
is_ok
)
{
if
(
i_c
==
0
)
printf
(
"(%d)pixel: %dx%d -> %d
\n
"
,
i_r
,
p
.
r
,
p
.
c
,
dst_idx
[
i_r
]);
err_cnt
++
;
}
}
}
#if TEST_SCATTER_GATHER_VERBOSE
printf
(
"err:%d
\n
"
,
err_cnt
);
#endif
free
(
src
);
free
(
dst
);
free
(
src_idx
);
free
(
dst_idx
);
return
err_cnt
==
0
?
0
:
-
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment