Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e599063f
Commit
e599063f
authored
May 10, 2024
by
illsilin
Browse files
sync from the public repo
parents
5dbbf5d6
566b6480
Changes
305
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1075 additions
and
0 deletions
+1075
-0
include/ck_tile/core/numeric/integral_constant.hpp
include/ck_tile/core/numeric/integral_constant.hpp
+83
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+550
-0
include/ck_tile/core/numeric/numeric.hpp
include/ck_tile/core/numeric/numeric.hpp
+191
-0
include/ck_tile/core/numeric/type_convert.hpp
include/ck_tile/core/numeric/type_convert.hpp
+66
-0
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+185
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
include/ck_tile/core/numeric/integral_constant.hpp
0 → 100644
View file @
e599063f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace
ck_tile
{
/// Compile-time value wrapper: the non-type template argument `v` is exposed as a
/// static member, so a value can be carried around purely in the type system.
template <auto v>
struct constant
{
    /// Type of the wrapped value, taken from the template argument itself.
    using value_type = decltype(v);
    /// This very specialization (spelled through the injected-class-name).
    using type = constant;

    /// The wrapped value.
    static constexpr value_type value = v;

    /// Always true: the value is fixed by the type.
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }

    /// Implicit conversion to the wrapped value.
    CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }

    /// Call syntax returning the wrapped value.
    CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
};
/// `constant<v>` with the value type `T` spelled out explicitly.
/// Re-declares `value_type`, `type` and `value` in terms of `T`.
template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type       = integral_constant; // using injected-class-name

    static constexpr T value = v;

    // The conversion operator and operator()() are intentionally not redeclared here
    // (they were left commented out upstream); the ones declared in the constant<v>
    // base class are the ones in effect.
};
/// Compile-time `index_t` value.
template <index_t v>
using number = constant<v>;

/// Compile-time `long_index_t` value.
template <long_index_t v>
using long_number = constant<v>;

/// Compile-time boolean value.
template <bool b>
using bool_constant = constant<b>;
// Operators on constant<> objects. Each operator applies OP to the wrapped value(s)
// and returns the result as a new constant<>, so the result stays encoded in the type.
// Both helper macros are local to this header and are #undef'd right after use.

// Prefix unary operator: OP constant<operand> -> constant<(OP operand)>
#define CK_TILE_LEFT_UNARY_OP(OP)                                     \
    template <auto operand>                                           \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<operand>) \
    {                                                                 \
        return constant<(OP operand)>{};                              \
    }

// Binary operator: constant<lhs> OP constant<rhs> -> constant<(lhs OP rhs)>
#define CK_TILE_BINARY_OP(OP)                                                 \
    template <auto lhs, auto rhs>                                             \
    CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<lhs>, constant<rhs>) \
    {                                                                         \
        return constant<(lhs OP rhs)>{};                                      \
    }

// unary
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)

// arithmetic
CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)

// bitwise and shifts
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)

// logical (note: as overloaded operators these do not short-circuit)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)

// comparison
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)

#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
}
// namespace ck_tile
include/ck_tile/core/numeric/math.hpp
0 → 100644
View file @
e599063f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace
ck_tile
{
/// Functor that multiplies its argument by a scale fixed at compile time.
/// The scale `lhs` is a non-type template argument of type `Scale`.
template <typename Scale, Scale lhs>
struct scales_c
{
    /// Returns `lhs * rhs`; the result type is whatever that product yields.
    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};
/// Functor that multiplies its argument by a scale stored at construction time.
/// `Scale` must be copy constructible because the scale is kept by value.
template <typename Scale>
struct scales
{
    static_assert(std::is_copy_constructible_v<Scale>);

    /// Stores a copy of `lhs` as the left-hand factor.
    CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}

    /// Returns `stored_scale * rhs`.
    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
        -> decltype(std::declval<const Scale&>() * rhs)
    {
        return lhs_ * rhs;
    }

    private:
    Scale lhs_; // left-hand factor applied to every argument
};

/// Deduction guide: `scales{s}` / `scales(s)` deduces `scales<decltype(s)>`.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale) -> scales<Scale>;
/// Addition functor for a fixed pair of operand types.
/// `Right` defaults to `Left`; with no arguments the transparent
/// `plus<void, void>` specialization below is selected.
template <typename Left = void, typename Right = Left>
struct plus
{
    /// Returns `lhs + rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

/// Transparent addition functor: operand types are deduced per call.
template <>
struct plus<void, void>
{
    /// Returns `lhs + rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

/// Deduction guide: `plus{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus()->plus<void, void>;
/// Subtraction functor for a fixed pair of operand types.
/// `Right` defaults to `Left`; with no arguments the transparent
/// `minus<void, void>` specialization below is selected.
template <typename Left = void, typename Right = Left>
struct minus
{
    /// Returns `lhs - rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

/// Transparent subtraction functor: operand types are deduced per call.
template <>
struct minus<void, void>
{
    /// Returns `lhs - rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

/// Deduction guide: `minus{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus()->minus<void, void>;
/// Multiplication functor for a fixed pair of operand types.
/// `Right` defaults to `Left`; with no arguments the transparent
/// `multiplies<void, void>` specialization below is selected.
template <typename Left = void, typename Right = Left>
struct multiplies
{
    /// Returns `lhs * rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

/// Transparent multiplication functor: operand types are deduced per call.
template <>
struct multiplies<void, void>
{
    /// Returns `lhs * rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

/// Deduction guide: `multiplies{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies()->multiplies<void, void>;
/// Functor returning the larger of two values of type `T`.
/// When `a >= b` holds the first argument is returned, otherwise the second.
template <typename T>
struct maximize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        if(a >= b)
        {
            return a;
        }
        return b;
    }
};
/// Functor returning the smaller of two values of type `T`.
/// When `a <= b` holds the first argument is returned, otherwise the second.
template <typename T>
struct minimize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        if(a <= b)
        {
            return a;
        }
        return b;
    }
};
/// Functor form of the "divide and round up" idiom `(a + b - 1) / b`.
/// Only instantiable for `T` equal to `index_t` or `int`.
template <typename T>
struct integer_divide_ceiler
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        static_assert(std::is_same_v<T, index_t> || std::is_same_v<T, int>, "wrong type");

        // bias the numerator by (b - 1) so the truncating division rounds up
        return (a + b - number<1>{}) / b;
    }
};
/// Plain `x / y`. The name documents the intent at call sites where the
/// operands are integers and the quotient is meant to be truncated.
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}
/// Computes `(x + y - 1) / y`, i.e. x / y rounded up for non-negative x and positive y.
/// The `1` is spelled as `number<1>{}` so that the whole expression stays a
/// compile-time constant when both operands are `constant<>` objects.
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - number<1>{}) / y;
}
/// Returns `y * integer_divide_ceil(x, y)`: x rounded up to a multiple of y
/// (for the operand ranges where integer_divide_ceil rounds up).
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}
/// max of a single value: the value itself (ends the variadic recursion below).
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
    return x;
}

/// Host-side max of two values of the same type, decided by `x > y`.
template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
    if(x > y)
    {
        return x;
    }
    return y;
}

/// Device-side max of two values of the same type, decided by `x > y`.
template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
    if(x > y)
    {
        return x;
    }
    return y;
}

/// Device-side float max, routed through the compiler builtin.
template <>
CK_TILE_DEVICE constexpr float max(float x, float y)
{
    return __builtin_fmaxf(x, y); // can result in v_max3_f32
}

/// Device-side double max, routed through the compiler builtin.
template <>
CK_TILE_DEVICE constexpr double max(double x, double y)
{
    return __builtin_fmax(x, y); // maybe still v_max3_f32
}

/// max of a compile-time number and a runtime index; the result is a runtime index.
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
    return X > y ? X : y;
}

/// max of a runtime index and a compile-time number; the result is a runtime index.
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
    return x > Y ? x : Y;
}

/// max of two or more values: folds from the right as max(x, max(ys...)).
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return max(x, max(ys...));
}
/// min of a single value: the value itself (ends the variadic recursion below).
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
    return x;
}

/// Host-side min of two values of the same type, decided by `x < y`.
template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
    if(x < y)
    {
        return x;
    }
    return y;
}

/// Device-side min of two values of the same type, decided by `x < y`.
template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
    if(x < y)
    {
        return x;
    }
    return y;
}

/// Device-side float min, routed through the compiler builtin.
template <>
CK_TILE_DEVICE constexpr float min(float x, float y)
{
    return __builtin_fminf(x, y);
}

/// Device-side double min, routed through the compiler builtin.
template <>
CK_TILE_DEVICE constexpr double min(double x, double y)
{
    return __builtin_fmin(x, y);
}

/// min of a compile-time number and a runtime index; the result is a runtime index.
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
    return X < y ? X : y;
}

/// min of a runtime index and a compile-time number; the result is a runtime index.
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
    return x < Y ? x : Y;
}

/// min of two or more values: folds from the right as min(x, min(ys...)).
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return min(x, min(ys...));
}
/// Limits `x` to the range given by `lowerbound` and `upperbound`:
/// first raises x to at least lowerbound, then caps the result at upperbound.
template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    return min(max(x, lowerbound), upperbound);
}
/// Count leading zero bits of a 32-bit value (host build): uses `__builtin_clz`.
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }

/// Count leading zero bits of a 32-bit value (device build): uses `__clz`.
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
// Recursive Euclidean algorithm on runtime (or constant-evaluated) indices.
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
    // strip signs first so the remainder steps only ever see non-negative operands
    if(x < 0)
    {
        return gcd(-x, y);
    }
    if(y < 0)
    {
        return gcd(x, -y);
    }

    // equal operands, or x == 0: y is the answer (this also yields gcd(0, 0) == 0)
    if(x == y || x == 0)
    {
        return y;
    }
    if(y == 0)
    {
        return x;
    }

    // reduce the larger operand modulo the smaller one and recurse
    return x > y ? gcd(x % y, y) : gcd(x, y % x);
}
/// gcd of two compile-time numbers; the result is again a compile-time number.
template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
    constexpr auto divisor = gcd(X, Y);
    return number<divisor>{};
}
template
<
typename
X
,
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
gcd
(
X
x
,
Ys
...
ys
)
{
return
gcd
(
x
,
gcd
(
ys
...));
}
// least common multiple
// Computed as (x / gcd(x, y)) * y. Dividing before multiplying gives the same
// result as (x * y) / gcd(x, y) because gcd(x, y) divides x exactly, but it keeps the
// intermediate value no larger than the final result, so it does not overflow when
// only the product x * y would not fit.
// Note: gcd(0, 0) is 0, so lcm(0, 0) still divides by zero, as before.
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
    return (x / gcd(x, y)) * y;
}
template
<
typename
X
,
typename
...
Ys
,
typename
std
::
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
lcm
(
X
x
,
Ys
...
ys
)
{
return
lcm
(
x
,
lcm
(
ys
...));
}
/// Equality functor for a fixed pair of operand types (`Right` defaults to `Left`).
template <typename Left = void, typename Right = Left>
struct equal
{
    /// Returns `lhs == rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

/// Transparent equality functor: operand types are deduced per call.
template <>
struct equal<void, void>
{
    /// Returns `lhs == rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

/// Deduction guide: `equal{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal()->equal<void, void>;

/// float equality decided on the raw 32-bit patterns instead of operator==,
/// so two values are equal only if every bit matches.
template <>
struct equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        const auto lhs_bits = bit_cast<uint32_t>(lhs);
        const auto rhs_bits = bit_cast<uint32_t>(rhs);
        return lhs_bits == rhs_bits;
    }
};

/// double equality decided on the raw 64-bit patterns instead of operator==,
/// so two values are equal only if every bit matches.
template <>
struct equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        const auto lhs_bits = bit_cast<uint64_t>(lhs);
        const auto rhs_bits = bit_cast<uint64_t>(rhs);
        return lhs_bits == rhs_bits;
    }
};
/// Less-than functor for a fixed pair of operand types (`Right` defaults to `Left`).
template <typename Left = void, typename Right = Left>
struct less
{
    /// Returns `lhs < rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

/// Transparent less-than functor: operand types are deduced per call.
template <>
struct less<void, void>
{
    /// Returns `lhs < rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

/// Deduction guide: `less{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less()->less<void, void>;
/// Less-or-equal functor for a fixed pair of operand types (`Right` defaults to `Left`).
template <typename Left = void, typename Right = Left>
struct less_equal
{
    /// Returns `lhs <= rhs`.
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

/// Transparent less-or-equal functor: operand types are deduced per call.
template <>
struct less_equal<void, void>
{
    /// Returns `lhs <= rhs`.
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

/// Deduction guide: `less_equal{}` names the transparent functor.
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal()->less_equal<void, void>;

/// float variant: true when `lhs < rhs`, or when both have the identical 32-bit pattern.
template <>
struct less_equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        if(lhs < rhs)
        {
            return true;
        }
        // "equal" is decided bitwise, not with operator==
        return bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
    }
};

/// double variant: true when `lhs < rhs`, or when both have the identical 64-bit pattern.
template <>
struct less_equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        if(lhs < rhs)
        {
            return true;
        }
        // "equal" is decided bitwise, not with operator==
        return bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
    }
};
/// Smallest power of two that is >= x.
///
/// Uses `__builtin_clz` directly (as integer_log2_floor does) so the function can be
/// evaluated in constant expressions; the host/device `clz` wrappers are not constexpr.
/// Inputs <= 1 return 1 instead of reaching `__builtin_clz(0)`, whose result is undefined.
///
/// x must not exceed 0x40000000: the next power of two above that (1 << 31) does not
/// fit in int32_t.
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
    if(x <= 1)
    {
        return 1;
    }
    // (32 - clz(x - 1)) is the bit width of (x - 1); 1 << width is the first power
    // of two strictly greater than (x - 1), hence >= x.
    return 1 << (32 - __builtin_clz(static_cast<uint32_t>(x - 1)));
}
/// Compile-time next power of two, value passed as template argument:
/// `next_power_of_two<X>()` returns `number<P>` with P computed from X.
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
    constexpr index_t rounded = next_power_of_two(X);
    return number<rounded>{};
}

/// Compile-time next power of two, value passed as a `number<X>` object.
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
    constexpr index_t rounded = next_power_of_two(X);
    return number<rounded>{};
}
/// floor(log2(x)): index of the highest set bit of x.
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
    // TODO: x need to be 1 ~ 0x7fffffff
    // __builtin_clz will produce unexpected result if x is 0;
    const int leading_zeros = __builtin_clz(x);
    return 31 - leading_zeros;
}
/// True when x is a positive power of two (1, 2, 4, ...).
///
/// A power of two has exactly one bit set, so clearing the lowest set bit with
/// `x & (x - 1)` leaves zero. Unlike the previous clz-based check this is also
/// well defined for x <= 0, where it returns false; results for 1 ~ 0x7fffffff
/// are unchanged.
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
    return x > 0 && (x & (x - 1)) == 0;
}
// log2(e) as a literal; only defined here if nobody defined it before.
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

/// log2(e) as a typed constant. Only the float and double specializations exist;
/// the primary template is declared but intentionally left undefined.
template <typename T>
struct log2e;

template <>
struct log2e<float>
{
    static constexpr float value = static_cast<float>(C_LOG2E);
};

template <>
struct log2e<double>
{
    static constexpr double value = static_cast<double>(C_LOG2E);
};

/// Shorthand for `log2e<T>::value`; `T` defaults to double.
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
// math
/// Absolute value of a float, computed by clearing the top (sign) bit of its
/// 32-bit representation.
///
/// The bits are accessed with bit_cast (as isnan does) instead of writing one
/// union member and reading another, which is undefined behaviour in C++.
CK_TILE_HOST_DEVICE float abs(const float& x)
{
    const uint32_t magnitude_bits = bit_cast<uint32_t>(x) & 0x7fffffff;
    return bit_cast<float>(magnitude_bits);
}
/// NaN test on the raw bits of a float: with the top bit masked off, the remaining
/// 31 bits are compared against 0x7F800000; anything strictly greater is reported as NaN.
CK_TILE_HOST_DEVICE bool isnan(const float& x)
{
    const uint32_t magnitude_bits = bit_cast<uint32_t>(x) & 0x7fffffff;
    return magnitude_bits > 0x7F800000;
}
/// Square root, host build, float: forwards to std::sqrt.
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }

/// Square root, host build, double: forwards to std::sqrt.
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }

/// Square root, device build, float: forwards to the __builtin_amdgcn_sqrtf builtin.
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }

/// Square root, device build, double: forwards to the __builtin_amdgcn_sqrt builtin.
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }
/// e^x, device build: forwards to __expf.
CK_TILE_DEVICE float exp(float x) { return __expf(x); }

/// e^x, host build: forwards to std::expf.
CK_TILE_HOST float exp(float x) { return std::expf(x); }
/// 2^x, device build: forwards to exp2f.
CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }

/// 2^x, host build: forwards to std::exp2f.
CK_TILE_HOST float exp2(float x) { return std::exp2f(x); }
/// Natural logarithm, device build: forwards to __logf.
CK_TILE_DEVICE float log(float x) { return __logf(x); }

/// Natural logarithm, host build: forwards to std::logf.
CK_TILE_HOST float log(float x) { return std::logf(x); }
/// Sum of absolute differences plus accumulator, device build.
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
    // TODO: this is hacky, we use u16
    return __builtin_amdgcn_sad_u16(x, y, acc);
}

/// Sum of absolute differences plus accumulator, host build: |x - y| + acc on the
/// full 32-bit values.
/// NOTE(review): the device overload uses the u16 builtin, so host and device results
/// presumably only agree for inputs that fit the u16 form - confirm before relying on it.
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
    const uint32_t abs_diff = x > y ? (x - y) : (y - x);
    return abs_diff + acc;
}
}
// namespace ck_tile
include/ck_tile/core/numeric/numeric.hpp
0 → 100644
View file @
e599063f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace
ck_tile
{
// This struct provides
// 1. the limits of a given type, similar to std::numeric_limits
// 2. some pre-defined values: zero, one, ...
//
// The generic version below simply forwards to std::numeric_limits<T>.
template <typename T>
struct numeric
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity()
    {
        return std::numeric_limits<T>::infinity();
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }

    // additive identity
    CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }

    // multiplicative identity
    CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

    // log2(e) for float and double; every other type gets 0
    CK_TILE_HOST_DEVICE static constexpr T log2e()
    {
        if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
        {
            return static_cast<T>(C_LOG2E);
        }
        else
        {
            return 0; // TODO: integer?
        }
    }
};
/// Bit-layout description of a numeric type. Only specializations are defined;
/// the primary template is declared without a definition.
template <typename T>
struct numeric_traits;

/// Layout constants for 32-bit float.
template <>
struct numeric_traits<float>
{
    /// Unsigned integer type holding the raw bits.
    using bitwise_type = uint32_t;

    // field widths and exponent bias
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    // bit masks
    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;

    // bit patterns of special values
    static constexpr uint32_t Inf    = 0x7F800000;
    static constexpr uint32_t NegInf = 0xFF800000;
    static constexpr uint32_t NaN    = 0x7F800001;
    static constexpr uint32_t Neg0   = 0x80000000;
};
}
// namespace ck_tile
// CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
//
// Generates the free comparison, arithmetic, compound-assignment and
// increment/decrement operators for `type_` by round-tripping through float:
// every operand is converted with static_cast<float>, the operation is done in
// float, and arithmetic results are converted back with type_(float).
//
//   attr_ : attributes/specifiers placed in front of every generated operator
//   type_ : the type to equip; as used below it must be convertible to float via
//           static_cast<float>, constructible from float, and expose a `data`
//           member together with a nested `raw_type`
//
// Unary minus is the one exception to the float round trip: it flips the top bit
// of `data` directly (mask = 1 << (number of bits in type_ - 1)).
//
// Note: keep comments out of the macro body; a `//` comment on a continued line
// would swallow the rest of the definition.
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
include/ck_tile/core/numeric/type_convert.hpp
0 → 100644
View file @
e599063f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace
ck_tile
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
/// Convert `x` to `Y` with a plain static_cast (custom data type build).
/// The return type is `Y` with const/volatile/reference stripped.
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
    return static_cast<Y>(x);
}
#else
// Convert X to Y, both X and Y are non-const data types.
// This is the generic fallback (a static_cast); selected source/destination pairs are
// expected to be handled by explicit specializations of this template.
// Reference types are rejected at compile time.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
// The const qualifiers are stripped, the conversion is delegated to the non-const
// type_convert for that pair, and the result is cast back to Y.
// Reference types are rejected at compile time.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    using plain_y = std::remove_const_t<Y>;
    using plain_x = std::remove_const_t<X>;

    return static_cast<Y>(type_convert<plain_y, plain_x>(x));
}
// CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)
// Emits the explicit specialization type_convert<dtype_, stype_>, implemented by
// calling the helper named <sname_>_to_<dname_>, e.g. fp16_to_float.
// The macro is local to this header and is #undef'd right after use.
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)                    \
    template <>                                                                 \
    CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
    {                                                                           \
        return sname_##_to_##dname_(x);                                         \
    }

// narrow types -> float
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)

// float -> narrow types
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

#undef CK_TILE_TYPE_CONVERT
#endif
}
// namespace ck_tile
include/ck_tile/core/numeric/vector_type.hpp
0 → 100644
View file @
e599063f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// This trait selects the <base> type used in
//   using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allows native types + bool in that position (a custom type fails).
// Specialize this trait to map a custom type to a suitable <base> type.
// By default the type maps to itself with const/volatile/reference removed.
template <typename T>
struct native_t
{
    using type = remove_cvref_t<T>;
};
// This is named ext_vector on purpose: clang's ext_vector_type extension only accepts
// literal basic types as the element type. Be very careful when using it, otherwise
// expect a lot of compiler errors, e.g.
//   struct A; using Ax2_t = A __attribute__((ext_vector_type(2)));   -> compiler error
namespace impl {

/// Builds the clang extended vector type with N_ elements whose element type is
/// the native type that T_ maps to (see native_t).
template <typename T_, index_t N_>
struct ext_vector
{
    /// Number of elements in the resulting vector.
    static constexpr index_t N = N_;

    /// Element type after cv/ref removal and native_t mapping; must not be a class type.
    using value_type = typename native_t<remove_cvref_t<T_>>::type;
    static_assert(!std::is_class_v<value_type>);

    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

/// Vector-of-vector case: when the element is itself an extended vector of Vs_
/// elements, the result is flattened into one vector of Vs_ * N_ scalars.
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
    /// Total number of scalar elements in the resulting vector.
    static constexpr index_t N = Vs_ * N_;

    /// Scalar element type after cv/ref removal and native_t mapping; must not be a class type.
    using value_type = typename native_t<remove_cvref_t<V_>>::type;
    static_assert(!std::is_class_v<value_type>);

    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

} // namespace impl

/// Extended vector type of N elements of (the native form of) T.
template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;
// By default, any type is described as a vector of size 1 whose scalar type is T
// (with const/volatile/reference removed) ...
// ... unless another vector_traits specialization matches.
template <typename T>
struct vector_traits
{
    using scalar_type                    = remove_cvref_t<T>;
    static constexpr index_t vector_size = 1;
};

// specialization for ext_vector_type(): reports the element type and element count
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
    using scalar_type                    = T;
    static constexpr index_t vector_size = N;
};
/// `std::is_same` applied to the scalar types that vector_traits reports for X and Y
/// (each taken after removing const/volatile/reference).
template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                          typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
// Pre-defined ext_vector_type aliases follow.
// Attention! Two vector aliases may name exactly the same type.
// Naming scheme: <scalar>x<N>_t is a vector of N elements of <scalar>.

// fp64
using fp64_t   = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));

// fp32
using fp32_t    = float;
using fp32x2_t  = float __attribute__((ext_vector_type(2)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));
using fp32x8_t  = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));

// fp16
// using fp16_t = ...
using fp16x2_t  = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t  = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t  = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));

// bf16
// using bf16_t = ...
using bf16x2_t  = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t  = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t  = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));

// i32
// using int32_t = ...
using int32x2_t  = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t  = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t  = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

// i16
// using int16_t = ...
using int16x2_t  = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t  = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t  = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));

// u16
// using uint16_t
using uint16x2_t  = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t  = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t  = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// 8-bit vector aliases. The attribute is spelled `__attribute__` here, matching
// every other vector alias in this header (these used the `__attribute` alternate
// spelling before; both spellings name the same keyword).

// i8
// using int8_t
using int8x2_t  = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// With custom data types enabled the vectors are built from the raw storage types.

// f8
// using fp8_t
using fp8x2_t  = fp8_raw_t __attribute__((ext_vector_type(2)));
using fp8x4_t  = fp8_raw_t __attribute__((ext_vector_type(4)));
using fp8x8_t  = fp8_raw_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_raw_t __attribute__((ext_vector_type(2)));
using bf8x4_t  = bf8_raw_t __attribute__((ext_vector_type(4)));
using bf8x8_t  = bf8_raw_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
#else
// Otherwise fp8_t / bf8_t are used directly as the element types.

// f8
// using fp8_t
using fp8x2_t  = fp8_t __attribute__((ext_vector_type(2)));
using fp8x4_t  = fp8_t __attribute__((ext_vector_type(4)));
using fp8x8_t  = fp8_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_t __attribute__((ext_vector_type(2)));
using bf8x4_t  = bf8_t __attribute__((ext_vector_type(4)));
using bf8x8_t  = bf8_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
#endif
}
// namespace ck_tile
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment