Project: gaoqiong/composable_kernel_ROCM · Commits

Unverified commit 4396a224, authored Apr 16, 2024 by Harisankar Sadasivan, committed via GitHub on Apr 16, 2024:

    Merge branch 'develop' into mi300_time_measurement

Parents: 0a27f07e, 501a6b68
Changes: 187 files in the merge; showing 20 changed files with 6372 additions and 0 deletions (+6372, -0):
include/ck_tile/core/container/sequence.hpp                   +1114  -0
include/ck_tile/core/container/span.hpp                         +78  -0
include/ck_tile/core/container/statically_indexed_array.hpp     +41  -0
include/ck_tile/core/container/thread_buffer.hpp               +165  -0
include/ck_tile/core/container/tuple.hpp                       +781  -0
include/ck_tile/core/numeric/bfloat16.hpp                      +342  -0
include/ck_tile/core/numeric/float8.hpp                        +871  -0
include/ck_tile/core/numeric/half.hpp                          +385  -0
include/ck_tile/core/numeric/integer.hpp                        +13  -0
include/ck_tile/core/numeric/integral_constant.hpp              +83  -0
include/ck_tile/core/numeric/math.hpp                          +539  -0
include/ck_tile/core/numeric/numeric.hpp                       +191  -0
include/ck_tile/core/numeric/type_convert.hpp                   +66  -0
include/ck_tile/core/numeric/vector_type.hpp                   +185  -0
include/ck_tile/core/tensor/buffer_view.hpp                   +1068  -0
include/ck_tile/core/tensor/load_tile.hpp                       +81  -0
include/ck_tile/core/tensor/null_tensor.hpp                     +12  -0
include/ck_tile/core/tensor/null_tile_window.hpp                +88  -0
include/ck_tile/core/tensor/shuffle_tile.hpp                   +177  -0
include/ck_tile/core/tensor/slice_tile.hpp                      +92  -0
include/ck_tile/core/container/sequence.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {

template <index_t, index_t, index_t>
struct static_for;

template <index_t...>
struct sequence;

template <typename Seq, index_t I>
struct sequence_split;

template <typename>
struct sequence_reverse;

template <typename>
struct sequence_map_inverse;

template <typename>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>);

template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq);

namespace impl {
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
template <index_t I, typename... Ts>
using at_index_t = __type_pack_element<I, Ts...>;
} // namespace impl
// we could implement as below, similar to std. But let's reduce the symbol name...
// template< class T, T... Ints >
// class integer_sequence;
template <index_t... Is>
struct sequence
{
    using type       = sequence;
    using value_type = index_t;

    CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); }

    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; };

    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr auto get()
    {
        static_assert(I < size(), "wrong! I too large");
        return number<impl::at_index_t<I, constant<Is>...>{}>{};
    }

    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr auto get(number<I>)
    {
        static_assert(I < size(), "wrong! I too large");
        return number<get<I>()>{};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I)
    {
        // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
        const index_t mData[size() + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr auto at()
    {
        static_assert(I < size(), "wrong! I too large");
        return number<impl::at_index_t<I, constant<Is>...>{}>{};
    }

    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr auto at(number<I>)
    {
        static_assert(I < size(), "wrong! I too large");
        return number<get<I>()>{};
    }

    template <typename I>
    CK_TILE_HOST_DEVICE constexpr auto operator[](I i) const
    {
        return at(i);
    }

    template <index_t... IRs>
    CK_TILE_HOST_DEVICE static constexpr auto reorder_new_to_old(sequence<IRs...> /*new2old*/)
    {
        static_assert(sizeof...(Is) == sizeof...(IRs),
                      "wrong! reorder map should have the same size as sequence to be reordered");
        static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");

        return sequence<type::get(number<IRs>{})...>{};
    }

    // MapOld2New is sequence<...>
    template <typename MapOld2New>
    CK_TILE_HOST_DEVICE static constexpr auto reorder_old_to_new(MapOld2New)
    {
        static_assert(MapOld2New::size() == size(),
                      "wrong! reorder map should have the same size as sequence to be reordered");
        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto reverse()
    {
        return typename sequence_reverse<type>::type{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto front()
    {
        static_assert(size() > 0, "wrong!");
        return get(number<0>{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto back()
    {
        static_assert(size() > 0, "wrong!");
        return get(number<size() - 1>{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto pop_front() { return sequence_pop_front(type{}); }

    CK_TILE_HOST_DEVICE static constexpr auto pop_back() { return sequence_pop_back(type{}); }

    template <index_t... Xs>
    CK_TILE_HOST_DEVICE static constexpr auto push_front(sequence<Xs...>)
    {
        return sequence<Xs..., Is...>{};
    }

    template <index_t... Xs>
    CK_TILE_HOST_DEVICE static constexpr auto push_front(number<Xs>...)
    {
        return sequence<Xs..., Is...>{};
    }

    template <index_t... Xs>
    CK_TILE_HOST_DEVICE static constexpr auto push_back(sequence<Xs...>)
    {
        return sequence<Is..., Xs...>{};
    }

    template <index_t... Xs>
    CK_TILE_HOST_DEVICE static constexpr auto push_back(number<Xs>...)
    {
        return sequence<Is..., Xs...>{};
    }

    // pickup element at index <Ids...>
    template <index_t... Ids>
    CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
    {
        return sequence<type::get(number<Ids>{})...>{};
    }

    template <index_t... Ids>
    CK_TILE_HOST_DEVICE static constexpr auto extract(sequence<Ids...>)
    {
        return sequence<type::get(number<Ids>{})...>{};
    }

    // modify element at index "I" with value "X"
    template <index_t I, index_t X>
    CK_TILE_HOST_DEVICE static constexpr auto modify(number<I>, number<X>)
    {
        static_assert(I < size(), "wrong!");

        using seq_split = sequence_split<type, I>;

        constexpr auto seq_left  = typename seq_split::left_type{};
        constexpr auto seq_right = typename seq_split::right_type{}.pop_front();

        return seq_left.push_back(number<X>{}).push_back(seq_right);
    }

    template <typename F>
    CK_TILE_HOST_DEVICE static constexpr auto transform(F f)
    {
        return sequence<f(Is)...>{};
    }

    CK_TILE_HOST_DEVICE static void print()
    {
        printf("sequence{size: %d, data: [", size());
        ((printf("%d ", Is)), ...);
        printf("]}");
    }
};
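// --- Illustrative sketch (editor's addition, not part of this commit): typical use of
// the compile-time sequence<> API above. Every query resolves at compile time, so the
// results can be pinned down with static_assert:
//
//   using s = sequence<2, 3, 4>;
//   static_assert(s::size() == 3, "");
//   static_assert(s::at(1) == 3, "");              // index known only as a constexpr value
//   static_assert(s::get(number<1>{}) == 3, "");   // index carried in the type
//
//   s::reorder_new_to_old(sequence<2, 0, 1>{})     // -> sequence<4, 2, 3>: element i of
//                                                  //    the result is old element new2old[i]
//   s{}.reverse()                                  // -> sequence<4, 3, 2>
//   s{}.push_back(number<5>{})                     // -> sequence<2, 3, 4, 5>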
namespace impl {

template <typename T, T... Ints>
struct __integer_sequence;

template <index_t... Ints>
struct __integer_sequence<index_t, Ints...>
{
    using seq_type = sequence<Ints...>;
};

} // namespace impl

// similar to std::make_index_sequence, built on the compiler intrinsic
template <index_t N>
using make_index_sequence =
    typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};

template <index_t... Xs, index_t... Ys>
struct sequence_merge<sequence<Xs...>, sequence<Ys...>>
{
    using type = sequence<Xs..., Ys...>;
};

template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};

// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
{
    template <index_t IBegin, index_t NRemain, typename G>
    struct sequence_gen_impl
    {
        static constexpr index_t NRemainLeft  = NRemain / 2;
        static constexpr index_t NRemainRight = NRemain - NRemainLeft;
        static constexpr index_t IMiddle      = IBegin + NRemainLeft;

        using type = typename sequence_merge<
            typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
            typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
    };

    template <index_t I, typename G>
    struct sequence_gen_impl<I, 1, G>
    {
        static constexpr index_t Is = G{}(number<I>{});
        using type                  = sequence<Is>;
    };

    template <index_t I, typename G>
    struct sequence_gen_impl<I, 0, G>
    {
        using type = sequence<>;
    };

    using type = typename sequence_gen_impl<0, NSize, F>::type;
};
// arithmetic sequence
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
    struct F
    {
        CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
    using type1 = sequence<>;

    static constexpr bool kHasContent =
        (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);

    using type = typename std::conditional<kHasContent, type0, type1>::type;
};

template <index_t IEnd>
struct arithmetic_sequence_gen<0, IEnd, 1>
{
    using type = make_index_sequence<IEnd>;
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
    struct F
    {
        CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t) const { return I; }
    };

    using type = typename sequence_gen<NSize, F>::type;
};
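// --- Illustrative sketch (editor's addition, not part of this commit): what the
// generators above expand to.
//   arithmetic_sequence_gen<0, 8, 2>::type -> sequence<0, 2, 4, 6>
//   arithmetic_sequence_gen<0, 4, 1>::type -> make_index_sequence<4>, i.e. sequence<0, 1, 2, 3>
//   uniform_sequence_gen<4, 7>::type       -> sequence<7, 7, 7, 7>
// sequence_gen builds its result by recursive halving rather than element-by-element
// appending, which keeps template instantiation depth at O(log N).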
// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;

template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I, Is...>, Reduce, Init>
{
    using old_scan = typename sequence_reverse_inclusive_scan<sequence<Is...>, Reduce, Init>::type;

    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front());

    using type = typename sequence_merge<sequence<new_reduce>, old_scan>::type;
};

template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I>, Reduce, Init>
{
    using type = sequence<Reduce{}(I, Init)>;
};

template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<>, Reduce, Init>
{
    using type = sequence<>;
};

// split sequence
template <typename Seq, index_t I>
struct sequence_split
{
    static constexpr index_t NSize = Seq{}.size();

    using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
    using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;

    using left_type  = decltype(Seq::extract(range0{}));
    using right_type = decltype(Seq::extract(range1{}));
};
#if 0
// reverse sequence
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
struct sequence_reverse<sequence<I>>
{
using type = sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<sequence<I0, I1>>
{
using type = sequence<I1, I0>;
};
#endif
namespace impl {

template <typename Id, index_t... Ns>
struct seq_reverse;

template <index_t... Ids, index_t... Ns>
struct seq_reverse<sequence<Ids...>, Ns...>
{
    template <index_t I>
    using element = impl::at_index_t<I, constant<Ns>...>;

    using type = sequence<element<(sizeof...(Ns) - 1 - Ids)>::value...>;
};

} // namespace impl

template <index_t... Ns>
struct sequence_reverse<sequence<Ns...>>
    : impl::seq_reverse<make_index_sequence<sizeof...(Ns)>, Ns...>
{
};

// template <index_t... Ns>
// using sequence_reverse_t = typename sequence_reverse<Ns...>::type;

#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using type = typename sequence_reduce<Reduce,
                                          Seq,
                                          typename sequence_reduce<Reduce, Seqs...>::type>::type;
};

template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, sequence<Xs...>, sequence<Ys...>>
{
    using type = sequence<Reduce{}(Xs, Ys)...>;
};

template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl
    {
        static constexpr bool choose_left = LeftValues::front() < RightValues::front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::front() : RightValues::front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();

        using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::push_back(number<chosen_id>{}));

        using new_left_values = typename std::
            conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
        using new_left_ids =
            typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;

        using new_right_values = typename std::
            conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
        using new_right_ids = typename std::
            conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;

        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      sequence<>,
                                      sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
    };

    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<sequence<>,
                                      sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
    struct sorted_sequence_merge
    {
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 sequence<>,
                                                 sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    static constexpr index_t nsize = Values::size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;

    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;
    using left_sort          = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;

    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;
    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
};

template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
{
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values = typename std::
        conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
    using sorted_ids =
        typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
};

template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
{
    using sorted_values = sequence<Value>;
    using sorted_ids    = sequence<Id>;
};

template <typename Compare>
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
{
    using sorted_values = sequence<>;
    using sorted_ids    = sequence<>;
};

template <typename Values, typename Compare>
struct sequence_sort
{
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
};
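// --- Illustrative sketch (editor's addition, not part of this commit): sequence_sort is
// a compile-time merge sort that also reports where each sorted element came from:
//   using sort = sequence_sort<sequence<3, 1, 2>, less<index_t>>;
//   // sort::type                -> sequence<1, 2, 3>
//   // sort::sorted2unsorted_map -> sequence<1, 2, 0>  (sorted element i sat at
//   //                              unsorted position map[i])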
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
    struct sorted_sequence_uniquify_impl
    {
        static constexpr index_t current_value = RemainValues::front();
        static constexpr index_t current_id    = RemainIds::front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::back());

        using new_remain_values = decltype(RemainValues::pop_front());
        using new_remain_ids    = decltype(RemainIds::pop_front());

        using new_uniquified_values = typename std::
            conditional<is_unique_value,
                        decltype(UniquifiedValues::push_back(number<current_value>{})),
                        UniquifiedValues>::type;
        using new_uniquified_ids = typename std::
            conditional<is_unique_value,
                        decltype(UniquifiedIds::push_back(number<current_id>{})),
                        UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
    };

    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<sequence<>,
                                         sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
    {
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
    };

    template <typename SortedValues, typename SortedIds, typename Eq>
    struct sorted_sequence_uniquify
    {
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::pop_front()),
                                                       decltype(SortedIds::pop_front()),
                                                       sequence<SortedValues::front()>,
                                                       sequence<SortedIds::front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
    };

    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;

    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
template <typename SeqMap>
struct is_valid_sequence_map
    : std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
                   typename sequence_sort<SeqMap, less<index_t>>::type>
{
};

template <typename SeqMap>
struct sequence_map_inverse
{
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::modify(X2Y::get(number<XBegin>{}), number<XBegin>{});

        using type = typename sequence_map_inverse_impl<X2Y,
                                                        decltype(new_y2x),
                                                        XBegin + 1,
                                                        XRemain - 1>::type;
    };

    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };

    using type =
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::size(), 0>::type,
                                           0,
                                           SeqMap::size()>::type;
};
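// --- Illustrative sketch (editor's addition, not part of this commit): a valid map is a
// permutation of 0..N-1 (checked by sorting it and comparing against 0, 1, ..., N-1), and
// sequence_map_inverse flips such a permutation:
//   is_valid_sequence_map<sequence<2, 0, 1>>::value -> true
//   is_valid_sequence_map<sequence<2, 2, 1>>::value -> false (2 repeats, not a permutation)
//   sequence_map_inverse<sequence<2, 0, 1>>::type   -> sequence<1, 2, 0>
// i.e. if the forward map sends x to y = map[x], the inverse satisfies inv[y] = x.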
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator==(sequence<Xs...>, sequence<Ys...>)
{
    return ((Xs == Ys) && ...);
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator!=(sequence<Xs...> x, sequence<Ys...> y)
{
    return !(x == y);
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    return sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    return sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    return sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    return sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
    return sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs + Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs - Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs * Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs / Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs % Y)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator+(number<Y>, sequence<Xs...>)
{
    return sequence<(Y + Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator-(number<Y>, sequence<Xs...>)
{
    return sequence<(Y - Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator*(number<Y>, sequence<Xs...>)
{
    return sequence<(Y * Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator/(number<Y>, sequence<Xs...>)
{
    return sequence<(Y / Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator%(number<Y>, sequence<Xs...>)
{
    return sequence<(Y % Xs)...>{};
}
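// --- Illustrative sketch (editor's addition, not part of this commit): the operators
// above act element-wise and are evaluated entirely at compile time:
//   sequence<1, 2, 3>{} + sequence<10, 20, 30>{} -> sequence<11, 22, 33>
//   sequence<1, 2, 3>{} * number<2>{}            -> sequence<2, 4, 6>
//   number<12>{} % sequence<5, 7, 9>{}           -> sequence<2, 5, 3>
// Sequence-sequence forms require equal sizes ("wrong! inconsistent size"); the
// number<> forms broadcast the scalar across every element.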
template <index_t I, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>)
{
    return sequence<Is...>{};
}

template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::size() > 0, "wrong! cannot pop an empty sequence!");
    return sequence_pop_front(Seq::reverse()).reverse();
}

template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
{
    return typename sequence_merge<Seqs...>::type{};
}

template <typename F, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>)
{
    return sequence<f(Xs)...>{};
}

template <typename F, index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sequence<Xs...>::size() == sequence<Ys...>::size(), "Dim not the same");

    return sequence<f(Xs, Ys)...>{};
}

template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
CK_TILE_HOST_DEVICE constexpr auto
transform_sequences(F f, sequence<Xs...>, sequence<Ys...>, sequence<Zs...>)
{
    static_assert(sequence<Xs...>::size() == sequence<Ys...>::size() &&
                      sequence<Xs...>::size() == sequence<Zs...>::size(),
                  "Dim not the same");

    return sequence<f(Xs, Ys, Zs)...>{};
}

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq::pop_front(), Reduce{}, number<Init>{})
        .push_back(number<Init>{});
}

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number<Init>{}).reverse();
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add
//      ResultSeq  TargetSeq  Reduce
template <typename, typename, typename>
struct sequence_exclusive_scan;

template <index_t... Xs, index_t Y, index_t... Ys, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y, Ys...>, Reduce>
{
    using old_scan = typename sequence_merge<
        sequence<Xs...>,
        sequence<Reduce{}(Y, sequence<Xs...>{}.back())>>::type;

    using type = typename sequence_exclusive_scan<old_scan, sequence<Ys...>, Reduce>::type;
};

template <index_t... Xs, index_t Y, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y>, Reduce>
{
    using type = sequence<Xs...>;
};

template <index_t... Xs, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<>, Reduce>
{
    using type = sequence<Xs...>;
};

template <typename Seq, typename Reduce, index_t Init>
constexpr auto exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    // TODO: c++20 and later can pass in Reduce with a lambda expression
    return typename sequence_exclusive_scan<sequence<Init>, Seq, Reduce>::type{};
}

template <typename Seq>
constexpr auto prefix_sum_sequence(Seq)
{
    return typename sequence_exclusive_scan<sequence<0>,
                                            typename sequence_merge<Seq, sequence<0>>::type,
                                            plus<index_t>>::type{};
}
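// --- Illustrative sketch (editor's addition, not part of this commit): the scan helpers
// above with plus<index_t> as the reduction:
//   inclusive_scan_sequence(sequence<2, 3, 4>{}, plus<index_t>{}, number<0>{})
//       -> sequence<2, 5, 9>
//   exclusive_scan_sequence(sequence<2, 3, 4>{}, plus<index_t>{}, number<0>{})
//       -> sequence<0, 2, 5>   (the Seq<2, 3, 4> example from the comment above)
//   prefix_sum_sequence(sequence<2, 3, 4>{})
//       -> sequence<0, 2, 5, 9>  (exclusive prefix sum with the running total appended)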
template <typename Seq, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_ids(Seq, sequence<Is...> /* ids */)
{
    return sequence<Seq::get(number<Is>{})...>{};
}

#if 1
namespace detail {

template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    using new_work_seq = typename std::conditional<RemainMask::front(),
                                                   decltype(WorkSeq::push_back(RemainSeq::front())),
                                                   WorkSeq>::type;

    using type =
        typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                     decltype(RemainSeq::pop_front()),
                                                     decltype(RemainMask::pop_front())>::type;
};

template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, sequence<>, sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

template <typename Seq, typename Mask>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::size() == Mask::size(), "wrong!");

    return typename detail::pick_sequence_elements_by_mask_impl<sequence<>, Seq, Mask>::type{};
}

namespace detail {

template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
    using new_work_seq = decltype(WorkSeq::modify(RemainIds::front(), RemainValues::front()));

    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::pop_front()),
                                                      decltype(RemainIds::pop_front())>::type;
};

template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, sequence<>, sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

template <typename Seq, typename Values, typename Ids>
CK_TILE_HOST_DEVICE constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::size() == Ids::size() && Seq::size() >= Values::size(), "wrong!");

    return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr index_t
reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
{
    index_t result = Init;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        result = f(result, Seq::at(i));
    }

    return result;
}

// TODO: a generic any_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
{
    bool flag = false;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        flag = flag || f(Seq::at(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
{
    bool flag = true;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        flag = flag && f(Seq::at(i));
    }

    return flag;
}
template <typename... Seqs>
using sequence_merge_t = typename sequence_merge<Seqs...>::type;

template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_sequence(number<Is>...)
{
    return sequence<Is...>{};
}

// F() returns index_t
// F is default-constructed, so F cannot be a lambda function
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number<N>)
{
    return typename sequence_gen<N, F>::type{};
}

// F() returns number<>
// F could be a lambda function
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
{
    return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
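// --- Illustrative sketch (editor's addition, not part of this commit): the difference
// between the two generators above. generate_sequence default-constructs F, so F must be
// a stateless functor returning index_t; generate_sequence_v2 unpacks over
// number<0> ... number<N-1> and therefore also accepts a lambda returning number<>:
//
//   struct times3 { constexpr index_t operator()(index_t i) const { return 3 * i; } };
//   generate_sequence(times3{}, number<4>{});  // -> sequence<0, 3, 6, 9>
//   generate_sequence_v2([](auto i) { return number<i.value * 2>{}; }, number<4>{});
//                                              // -> sequence<0, 2, 4, 6>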
template <class... T>
struct tuple;

template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
    return sequence<Is...>{};
}

namespace detail {

template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;

template <index_t h_idx, index_t x, index_t... xs, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x, xs...>, sequence<r, rs...>>
{
    template <typename Histogram>
    constexpr auto operator()(Histogram& h)
    {
        if constexpr(x < r)
        {
            h.template at<h_idx>() += 1;
            sorted_sequence_histogram<h_idx, sequence<xs...>, sequence<r, rs...>>{}(h);
        }
        else
        {
            h.template at<h_idx + 1>() = 1;
            sorted_sequence_histogram<h_idx + 1, sequence<xs...>, sequence<rs...>>{}(h);
        }
    }
};

template <index_t h_idx, index_t x, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
{
    template <typename Histogram>
    constexpr auto operator()(Histogram& h)
    {
        if constexpr(x < r)
        {
            h.template at<h_idx>() += 1;
        }
    }
};

} // namespace detail

template <typename, index_t>
struct array; // declare for later use (array->seq utility)

// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
template <typename SeqSortedSamples, index_t r, index_t... rs>
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence<r, rs...>)
{
    constexpr auto bins      = sizeof...(rs); // or categories
    constexpr auto histogram = [&]() {
        array<index_t, bins> h{0}; // make sure this can clear all elements to zero
        detail::sorted_sequence_histogram<0, SeqSortedSamples, sequence<rs...>>{}(h);
        return h;
    }();

    return TO_SEQUENCE(histogram, bins);
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
{
    using T = remove_cvref_t<decltype(f(number<0>{}))>;

    return unpack([&f](auto&&... is) { return array<T, N>{f(is)...}; },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

} // namespace ck_tile
include/ck_tile/core/container/span.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {

// implements the C++20 std::span: a lightweight, non-owning reference to a sequence,
// whether it is a dynamic or static range. Can also be seen as a view of a contiguous sequence.
// TODO: do we need this on device, considering it holds a pointer?
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}

    CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count)
        : ptr_(first), size_(count)
    {
    }

    CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
        : span(arr.data(), N)
    {
    }

    template <typename Container>
    CK_TILE_HOST_DEVICE constexpr span(const Container& container)
        : span(container.data(), container.size())
    {
    }

    CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }

    CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
    CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }

    CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
    CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); }

    CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
    {
        return *(begin() + idx);
    }

    CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck_tile
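// --- Illustrative sketch (editor's addition, not part of this commit): ck_tile::span
// follows the C++20 std::span shape (pointer + size, non-owning), so the usual patterns
// apply:
//   float storage[4] = {1.f, 2.f, 3.f, 4.f};
//   ck_tile::span<float> s(storage);               // from a C array, size deduced
//   float sum = 0.f;
//   for(auto v : s) { sum += v; }                  // begin()/end() enable range-for
//   ck_tile::span<const float> head(s.data(), 2);  // smaller view over the same memory
// Unlike std::span there is no static extent parameter and no bounds checking.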
include/ck_tile/core/container/statically_indexed_array.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {

#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
template <typename T, index_t N>
using statically_indexed_array = tuple_array<T, N>;
#else
// consider marking this alias as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif

// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}

// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
    return statically_indexed_array<X, 0>();
}
#endif

} // namespace ck_tile
include/ck_tile/core/container/thread_buffer.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {

#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_tuple(ts...);
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_array(ts...);
}
#endif

// clang-format off
template <typename T_, index_t N_>
struct thread_buffer
{
    using value_type = remove_cvref_t<T_>;
    static constexpr index_t N = N_;

    value_type data[N];

    // TODO: this ctor can't be omitted
    CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
    CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type& o) : data{o} {}

    CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }

    CK_TILE_HOST_DEVICE auto& get() { return data; }
    CK_TILE_HOST_DEVICE const auto& get() const { return data; }
    CK_TILE_HOST_DEVICE auto& get(index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }

    CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); }

    // TODO: compatible
    CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }

    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }

    template <typename X_,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr auto _get_as() const
    {
        using X = remove_cvref_t<X_>;
        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        static_assert(N % kSPerX == 0);
        union {
            thread_buffer<X_, N / kSPerX> data{};
            // tuple_array<value_type, kSPerX> sub_data;
            value_type sub_data[N];
        } vx;
        static_for<0, N, 1>{}([&](auto j) { vx.sub_data[j] = data[j]; });
        return vx.data;
    }

    template <typename X_,
              index_t Is,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
    {
        using X = remove_cvref_t<X_>;
        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        union {
            X_ data{};
            tuple_array<value_type, kSPerX> sub_data;
        } vx;
        static_for<0, kSPerX, 1>{}([&](auto j) {
            vx.sub_data(j) = operator[]((is * number<sizeof(X_) / sizeof(value_type)>{}) + j);
        });
        return vx.data;
    }

#if 0
    template <typename X_,
              index_t Is,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
    {
        using X = remove_cvref_t<X_>;
        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        union {
            X_ data;
            tuple_array<value_type, kSPerX> sub_data;
        } vx {x};
        static_for<0, kSPerX, 1>{}(
            [&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
    }
#endif

#define TB_COMMON_AS() \
    static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
    constexpr int vx = sizeof(value_type) * N / sizeof(Tx)

    template <typename Tx>
    CK_TILE_HOST_DEVICE auto& get_as()
    {
        TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data);
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr auto get_as() const
    {
        TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1)
            return _get_as<Tx>();
        // TODO: the current compiler needs a union to read 8-bit data back; should fix in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE auto& get_as(number<I>)
    {
        TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const
    {
        TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1)
            return _get_as<Tx>(number<I>{});
        // TODO: the current compiler needs a union to read 8-bit data back; should fix in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx& x)
    {
        TB_COMMON_AS();
        reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x;
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx& x)
    {
        TB_COMMON_AS();
        reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x;
    }

#undef TB_COMMON_AS
};
// clang-format on

template <typename>
struct vector_traits;

// specialization for array
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
    using scalar_type = T;
    static constexpr index_t vector_size = N;
};

#endif

} // namespace ck_tile
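// --- Illustrative sketch (editor's addition, not part of this commit): thread_buffer is
// a per-thread register-resident array; get_as<Tx>()/set_as<Tx>() reinterpret the same
// storage at a different element granularity (total byte size must divide evenly, as the
// TB_COMMON_AS static_assert enforces). Assuming fp32x2_t is a 2-wide float vector type
// from the vector_type.hpp added in this commit:
//   ck_tile::thread_buffer<float, 8> buf;
//   for(ck_tile::index_t i = 0; i < buf.size(); ++i) { buf(i) = static_cast<float>(i); }
//   auto& v = buf.get_as<fp32x2_t>();          // viewed as thread_buffer<fp32x2_t, 4>
//   buf.set_as(ck_tile::number<0>{}, v[1]);    // write elements {2, 3} into slot 0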
include/ck_tile/core/container/tuple.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif
namespace ck_tile {

namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}

template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;

namespace impl {

// the place where content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_object
{
};

template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
#endif
};

template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
#endif
    T element;
};

// NOTE: we return an instance (not a reference) if the content is empty
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
    return {};
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
    return static_cast<T&&>(x.element);
}

template <typename index_seq, typename... T>
struct tuple_base;

template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
    CK_TILE_HOST_DEVICE constexpr tuple_base() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (std::initializer_list<U>{}.size() - 1)
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
        : tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
    {
    }
#undef _ILE
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
        : tuple_object<I, T>(std::forward<U>(u))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <class U,
              typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
                                          !std::is_same<remove_cvref_t<U>, tuple_base>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
    {
    }

    template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
    {
        static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
                      "wrong! inconsistent size");
    }
#endif
};

} // namespace impl
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
    CK_TILE_HOST_DEVICE static constexpr auto size() { return sizeof...(T); }

    using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;

    CK_TILE_HOST_DEVICE constexpr tuple() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
    {
    }
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
        : base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<sizeof...(T) == 1 &&
                                          !std::is_same<remove_cvref_t<U>, tuple>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
    {
    }

    template <typename... U,
              typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }
#endif

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool flag = true;
        static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
            flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
        });
        return flag;
    }

#define TP_COM_() static_assert(I < size(), "wrong! out of range")
    // clang-format off
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }

    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }

    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); }

    // TODO: compatible
    // the functions below are intended for tuple_array<> types; no extra check is performed here
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }

    // the index below is the index *AFTER* the type conversion, not before
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }

    // template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx& x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
    // clang-format on
#undef TP_COM_
};
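// --- Illustrative sketch (editor's addition, not part of this commit): this tuple is
// indexed with compile-time number<> tags instead of std::get (make_tuple is defined
// further below):
//   auto t = make_tuple(1, 2.5f);          // tuple<int, float>
//   auto a = t.get(number<0>{});           // 1
//   t(number<1>{}) = 4.f;                  // operator() hands back a mutable reference
//   static_assert(decltype(t)::size() == 2, "");
// Empty element types take no storage: tuple_object<idx, T, true> stores nothing and
// getv returns a freshly value-constructed T, so tuples of tag types stay trivially small.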
template
<
typename
>
struct
vector_traits
;
// specialization for array
template
<
typename
...
T
>
struct
vector_traits
<
tuple
<
T
...
>>
{
using
scalar_type
=
__type_pack_element
<
0
,
T
...
>
;
static
constexpr
index_t
vector_size
=
sizeof
...(
T
);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    bool same = true;

    static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
        if(a[i] != b[i])
        {
            same = false;
        }
    });

    return same;
}

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    return !(a == b);
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
    // here each xs is always an lvalue as a function argument.
    // Xs may be deduced as follows (e.g. when passing in an integer):
    //   1) passing an rvalue (like a function return or int{}) -> Xs is "int"
    //   2) passing a const lvalue                               -> Xs is "const int&"
    //   3) passing a non-const lvalue                           -> Xs is "int&"
    // so the return type of std::forward depends on Xs:
    //   1) std::forward -> int&&
    //   2) std::forward -> const int&
    //   3) std::forward -> int&
    return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
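// usage sketch (illustration only, not part of the original header):
//   int i = 0;
//   auto t = make_tuple(i, 2, 3.0f); // tuple<int, int, float>; cv/ref stripped by remove_cvref_t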
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}

template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};
namespace impl {

// be very careful when using this type (because we want the internal type):
// template deduction will fail when inferring the inner type.
// e.g.
//   template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
//   template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> the compiler will fail to deduce this type, because it is a non-deduced context
//    (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
//    contexts")
//
// -> use this instead:
//   template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
    using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
                                       typename tuple_array_impl<T, N - N / 2>::type>::type;
};

template <typename T>
struct tuple_array_impl<T, 0>
{
    using type = tuple<>;
};

template <typename T>
struct tuple_array_impl<T, 1>
{
    using type = tuple<T>;
};

} // namespace impl
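// e.g. (illustration): impl::tuple_array_impl<float, 4>::type concatenates the results for
// N/2 = 2 and N - N/2 = 2, yielding tuple<float, float, float, float>; the binary split keeps
// recursive template instantiation depth at O(log N) rather than O(N).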
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
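// e.g. (illustration): generate_tuple([](auto i) { return i.value * 2; }, number<3>{})
// invokes f with number<0>, number<1>, number<2> and collects the results into a tuple (0, 2, 4).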
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return tie(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
                                                             const tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
    return tx;
}

template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {

template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}

} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
    return detail::transform_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    return detail::transform_tuples_impl(
        f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    return detail::transform_tuples_impl(
        f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
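// e.g. (illustration): transform_tuples([](auto a, auto b) { return a + b; },
//                                       make_tuple(1, 2), make_tuple(3, 4)) == make_tuple(4, 6);
// the element count is taken from X::size(), so every argument tuple must be at least that long.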
// By default unroll to the fully flattened form
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
    return t;
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
    return make_tuple(t);
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
{
    if constexpr(Depth == MaxDepth)
    {
        return t;
    }
    else
    {
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
            },
            t);
    }
}
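// e.g. (illustration): unroll_nested_tuple(make_tuple(make_tuple(1, 2), 3)) -> make_tuple(1, 2, 3);
// a non-negative MaxDepth stops the flattening once that nesting level is reached.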
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
    return generate_tuple(
        [&](auto i) {
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in the range [Idx, End) using Function f
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
{
    static_assert(Idx < End, "Wrong parameters for tuple_reduce");

    if constexpr(Idx + 1 == End)
    {
        return t.at(number<Idx>{});
    }
    else
    {
        return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
    }
}
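// e.g. (illustration): tuple_reduce<0, 3>([](auto a, auto b) { return a * b; }, make_tuple(2, 3, 4))
// evaluates right-associatively as f(t[0], f(t[1], t[2])) == 24.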
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
    return (is_detected<is_tuple, Ts>::value || ...);
}
template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
    return depth;
}

template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
    return max(tuple_depth<depth + 1>(Ts{})...);
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
    constexpr index_t n0 = sizeof...(Seqs);

    constexpr index_t max_n1 = [&] {
        index_t max_n1_ = 0;

        static_for<0, n0, 1>{}([&](auto i0) {
            constexpr index_t n1 = t_of_s[i0].size();

            max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
        });

        return max_n1_;
    }();

    array<array<index_t, max_n1>, n0> a_of_a{{-1}};

    static_for<0, n0, 1>{}([&](auto i0) {
        constexpr index_t n1 = t_of_s[i0].size();

        static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
    });

    return a_of_a;
}
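// e.g. (illustration): to_array_of_array(make_tuple(sequence<1, 2>{}, sequence<3>{})) returns an
// array<array<index_t, 2>, 2> holding {{1, 2}, {3, .}}; entries not written by a shorter
// sequence are intended to keep the -1 initializer as padding.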
// Here we should use MultiIndex<NSize> instead of tuple<Ys...>, although the former
// is an alias of the latter. This is because the compiler cannot infer NSize when
// using MultiIndex<NSize>.
// TODO: how to fix this?
template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);

    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });

    return y;
}
template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);

    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });

    return y;
}
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;

    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });

    return r;
}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;

    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });

    return r;
}
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;

    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });

    return r;
}
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
          typename Y,
          std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;

    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });

    return r;
}

// MultiIndex = MultiIndex * scalar
template <typename... Xs,
          typename Y,
          std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value,
                           bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
    return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);

    return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}
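// e.g. (illustration): with a = make_tuple(8, 6) and b = make_tuple(2, 3):
//   a + b == (10, 9), a - b == (6, 3), a * b == (16, 18), a / b == (4, 2), 2 * a == (16, 12);
// all operators work element-wise, and the static_asserts require matching sizes.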
} // namespace ck_tile
#include <tuple>

// WARNING: needed by the compiler for C++ structured binding support only, don't use this
namespace std {

template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
{
};

template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
    : std::tuple_element<I, const std::tuple<Ts...>>
{
};

} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif
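// usage sketch (illustration, assuming a constexpr ck_tile::array "a"):
//   constexpr ck_tile::array<ck_tile::index_t, 3> a{1, 2, 3};
//   auto t = TO_TUPLE_OF_NUMBER(a, 3); // ck_tile::tuple<number<1>, number<2>, number<3>>
// the C++20 templated lambda reads a[number<I>] at compile time for each index.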
include/ck_tile/core/numeric/bfloat16.hpp 0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>

#pragma once

namespace ck_tile {

enum class bf16_rounding_mode
{
    standard = 0, // rtn
    truncate_with_nan,
    truncate,
};
template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});

CK_TILE_HOST_DEVICE constexpr float bf16_to_float_raw(uint16_t x);

CK_TILE_HOST_DEVICE constexpr double bf16_to_double_raw(uint16_t x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses __hip_bfloat16 as a struct
struct alignas(2) bfloat16_t
{
    using raw_type = uint16_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr bfloat16_t bit_cast(raw_type x)
    {
        bfloat16_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr bfloat16_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const unsigned int& x)
        : data(float_to_bf16_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf16_to_float_raw(data); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return bf16_to_double_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

template <typename>
struct native_t;

template <>
struct native_t<bfloat16_t>
{
    using type = ushort;
};

using bf16_t     = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t     = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
// round to nearest
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    if(~u.int32 & 0x7f800000)
    {
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
    }
    else if(u.int32 & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        u.int32 |= 0x10000; // Preserve signaling NaN
    }

    return uint16_t(u.int32 >> 16);
}
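// worked example (illustration): f = 1.00390625f has bit pattern 0x3F808000; the low 16 bits
// are exactly 0x8000 (a tie) and bit 16 is 0 (even), so adding 0x7fff + 0 leaves the upper
// half at 0x3F80 and the result is bf16 1.0, i.e. the tie rounds to the even neighbor.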
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}

// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    return uint16_t(u.int32 >> 16);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
    if constexpr(rounding == bf16_rounding_mode::standard)
        return float_to_bf16_rtn_raw(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else
        return float_to_bf16_truc_raw(f);
}

template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
    return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};
    return u.fp32;
}

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x)
{
    return static_cast<double>(bf16_to_float_raw(x));
}
template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
template <class T>
struct numeric;

template <>
struct numeric<bfloat16_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
    }

    // difference between 1.0 and the next value representable as bfloat16
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
    }

    // maximum rounding error
    // bin : f edcba9876543210
    // bits: s eeeeeeee mmmmmmm
    //       0 01111110 0000000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
    }

    CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
    {
        return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
    }
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif

// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
    return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}

CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    uint16_t xx = bit_cast<bf16_raw_t>(x);
    return (xx & 0x7FFF) > 0x7F80; // NaN iff exponent bits (0x7F80) all 1s and mantissa nonzero
}

CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
    return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};

CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };

} // namespace ck_tile
include/ck_tile/core/numeric/float8.hpp 0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include <stdint.h>
#include <type_traits>

#pragma once

namespace ck_tile {

// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, which helps to avoid error accumulation
enum class fp8_rounding_mode
{
    standard = 0,
    stochastic
};

/*
 *            ______________NANOO_________________ | ______________IEEE________________
 *                  e4m3              e5m2          |      e4m3             e5m2
 * bias      : 8                 16                 | 7                15
 * inf       : 1.0000.000        1.00000.00         | N/A              s.11111.00
 * NaN       : 1.0000.000        1.00000.00         | s.1111.111       s.11111.{01, 10, 11}
 * zero      : 0.0000.000        0.00000.00         | s.0000.000       s.00000.00
 * Max(norm) : s.1111.111 (240)  s.11111.11 (57344) | s.1111.110 (448) s.11110.11 (57344)
 * Max(snorm): s.0000.111        s.00000.11         | s.0000.111 (448) s.00000.11 (57344)
 *             0.0068359375      2.288818e-05       | 0.013671875      4.57763671875e-05
 * Min(norm) : s.0001.000        s.00001.00         | s.0001.000       s.00001.00
 *             2^-7 (0.00078125) 2^-15 (3.05176e-05)| 2^-6 (0.015625)  2^-14 (6.10352e-05)
 * Min(snorm): s.0000.001        s.00000.01         | s.0000.001       s.00000.01
 *             2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
 */
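// worked example (illustration): in NANOO e4m3 (bias 8) the byte 0.0001.000 decodes as
// 1.0 * 2^(1-8) = 2^-7 = 0.00078125, matching the Min(norm) row above; the same byte under
// IEEE e4m3 (bias 7) decodes as 2^-6 = 0.015625.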
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});

CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);

CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif

    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e4m3_t bit_cast(raw_type x)
    {
        float8_e4m3_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e4m3_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
    {
    }

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const unsigned int& x)
        : data(float_to_fp8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};
using fp8_t     = float8_e4m3_t;
using fp8_raw_t = typename fp8_t::raw_type;

struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif

    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e5m2_t bit_cast(raw_type x)
    {
        float8_e5m2_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e5m2_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
    {
    }

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const unsigned int& x)
        : data(float_to_bf8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};
using bf8_t     = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;

template <typename>
struct native_t;

template <>
struct native_t<fp8_t>
{
    using type = _BitInt(8);
};

template <>
struct native_t<bf8_t>
{
    using type = unsigned _BitInt(8);
};
#else
using fp8_t     = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t     = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
// below is the sw fp8 conversion, not utilizing hw instructions
namespace impl {

template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int out_exp  = numeric_traits<Y>::exp;
    constexpr int out_mant = numeric_traits<Y>::mant;

    // original type exponent/mantissa layout
    constexpr int in_exp  = numeric_traits<X>::exp;
    constexpr int in_mant = numeric_traits<X>::mant;

    int exponent, bias;
    uint32_t head, mantissa, sign;

    // nan code is the same for float and half
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    constexpr Y nan_code =
        numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
#else
    constexpr Y nan_code = 0x80;
#endif
    constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;

    // convert to bitwise
    using T_bitwise     = typename numeric_traits<X>::bitwise_type;
    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

    // unpack the input, depends on the datatype
    head     = x_bitwise & numeric_traits<X>::head_mask;
    mantissa = x_bitwise & numeric_traits<X>::mant_mask;
    exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
    sign     = head >> (in_exp + in_mant);
    bias     = numeric_traits<X>::bias;

    uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
    uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);

    if constexpr(negative_zero_nan)
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return nan_code;
    }
    else
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    // check if x is 0.0
    if(x_bitwise == 0)
        return __builtin_bit_cast(Y, static_cast<uint8_t>(0));

    // First we need to check whether the value is normal or denormal, as there is a difference of
    // an implicit 1. Then we need to adjust the exponent to align with the F8 exponent and, in the
    // meanwhile, shift the mantissa. For stochastic rounding, add rng to the mantissa and
    // truncate; for RNE there is no need to add rng. Then we probably need to check whether there
    // is a carry and adjust the exponent and mantissa again.

    // For the IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of the exponent bits
    const int out_bias                  = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
    const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal

    // act_exponent is the actual exponent of fp32/fp16 (after subtracting the bias)
    // out_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between the fp32/fp16 exponent and the f8 exponent;
    // the difference needs to be adjusted and the mantissa shifted
    int act_exponent, out_exponent, exponent_diff;

    if(exponent == 0)
    {
        // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here; we mostly concern
        fp16 here. In this case, f8 is usually in denormal. But there could be exceptions. fp16
        denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that
        there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - the smallest
        bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and the
        highest bit of the mantissa is 1 are bf8 (NANOO) normal. In this case, the fp16 mantissa
        should be shifted left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = out_denormal_act_exponent -
                        act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    {
        // fp32/fp16 is normal with an implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= out_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in the f8 denormal range.
            For example, in fp8 nanoo mode the denormal exponent is -7, but if the fp32/fp16
            actual exponent is -7, it is actually larger due to the implicit 1.
            Therefore it needs to be adjusted to -6 and the mantissa shifted right by 1.
            So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = out_denormal_act_exponent - act_exponent;
        }
        else
        {
            // both fp32/fp16 and f8 are in the normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this
                               // case; act_exponent could be larger. Just that it does not need
                               // to shift the mantissa
        }
        mantissa += (1 << in_mant); // Add the implicit 1 into the mantissa
    }

    bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
                    (1 << (in_mant - out_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
    shift right, as the shift right could rip off some residual part and make something that is
    not a midpoint look like a midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010)
    is larger than the midpoint, but after a shift right by 4 bits it would look like a midpoint. */
    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << in_mant);
    // if there is no implicit 1, it means the f8 is denormal and we need to adjust to the denorm
    // exponent
    out_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    bool odd = mantissa &
               (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
    mantissa +=
        (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(out_exponent == 0)
    {
        if((1 << in_mant) & mantissa)
        {
            out_exponent = 1; // denormal overflows to become normal, promote exponent
            // No need to make the 1 implicit now as it will be addressed later
        }
    }
    else
    {
        if((1 << (in_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
            // No need to make the 1 implicit now as it will be addressed later
        }
    }

    mantissa >>= (in_mant - out_mant);

    if(out_exponent > max_exp)
    {
        if(clip)
        {
            mantissa     = (1 << out_mant) - 1;
            out_exponent = max_exp;
        }
        else
        {
            return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
        }
    }

    // check if x is 0.0 or -0.0
    if(out_exponent == 0 && mantissa == 0)
        return __builtin_bit_cast(
            Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
    mantissa &= (1 << out_mant) - 1;
    return __builtin_bit_cast(
        Y,
        static_cast<uint8_t>((sign << (out_exp + out_mant)) | (out_exponent << out_mant) |
                             mantissa));
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int in_exp  = numeric_traits<X>::exp;
    constexpr int in_mant = numeric_traits<X>::mant;

    // resulting type exponent/mantissa layout
    constexpr int out_exp  = numeric_traits<Y>::exp;
    constexpr int out_mant = numeric_traits<Y>::mant;

    uint8_t x_raw = __builtin_bit_cast(uint8_t, x);

    // prepare the codes
    constexpr uint8_t nan_code = 0x80;
    Y Inf, NegInf, NaN, Neg0;

    using T_bitwise = typename numeric_traits<Y>::bitwise_type;

    constexpr T_bitwise Inf_bitwise    = numeric_traits<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise    = numeric_traits<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise   = numeric_traits<Y>::Neg0;

    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));

    // check if x is 0.0
    if(x_raw == 0)
        return static_cast<Y>(0);

    // unpack the input
    uint32_t sign     = x_raw >> (in_exp + in_mant);
    uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
    int exponent      = (x_raw & 0x7F) >> in_mant;

    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    T_bitwise retval;

    if constexpr(negative_zero_nan)
    {
        if(x_raw == nan_code)
            return NaN;
    }
    else
    {
        if(x_raw == nan_code)
            return Neg0;
        if(exponent == ((1 << in_exp) - 1))
            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

    if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
    {
        retval = x_raw;
        retval <<= 8;
        return *(reinterpret_cast<const Y*>(&retval));
    }

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return *(reinterpret_cast<const Y*>(&retval));
}
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    // check datatypes
    constexpr bool is_half  = std::is_same<X, half_t>::value;
    constexpr bool is_float = std::is_same<X, float>::value;
    static_assert(is_half || is_float, "Only half and float can be casted.");

    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}

template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
    // check datatype
    constexpr bool is_half  = std::is_same<Y, half_t>::value;
    constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "only half and float are supported.");

    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}

} // namespace impl
CK_TILE_HOST_DEVICE
fp8_raw_t float_to_fp8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<fp8_raw_t>(
        impl::cast_to_f8<float,
                         fp8_t,
                         negative_zero_nan,
                         clip,
                         (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE
bf8_raw_t float_to_bf8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<bf8_raw_t>(
        impl::cast_to_f8<float,
                         bf8_t,
                         negative_zero_nan,
                         clip,
                         (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE
fp8_raw_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val    = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<fp8_raw_t>(
        impl::cast_to_f8<float,
                         fp8_t,
                         negative_zero_nan,
                         clip,
                         (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE
bf8_raw_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val    = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<bf8_raw_t>(
        impl::cast_to_f8<float,
                         bf8_t,
                         negative_zero_nan,
                         clip,
                         (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
// clang-format off
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)        return float_to_fp8_rtn_raw(x);
    else if constexpr(rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
    else                                                         return fp8_raw_t{0};
}

template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)        return float_to_bf8_rtn_raw(x);
    else if constexpr(rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
    else                                                         return bf8_raw_t{0};
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif
}

CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif
}
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
    return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
    return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }

CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
// clang-format on
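// usage sketch (illustration only):
//   fp8_t v = float_to_fp8(1.5f);  // default (standard/RNE-style) rounding
//   float f = fp8_to_float(v);     // 1.5 is exactly representable in e4m3, so f == 1.5f
//   fp8_t s = float_to_fp8(1.5f, constant<fp8_rounding_mode::stochastic>{});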
template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<fp8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 8; // NANOO
#else
    static constexpr int bias = 7; // IEEE
#endif
};

template <>
struct numeric_traits<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 16; // NANOO
#else
    static constexpr int bias = 15; // IEEE
#endif
};
template <class T>
struct numeric;

template <>
struct numeric<fp8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr fp8_t min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
    }

    // difference between 1.0 and the next value representable as fp8
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
    }

    // maximum rounding error
    // bin : 7 6543 210
    // bits: s eeee mmm
    //       0 0110 000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
    {
        return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
    }
};

template <>
struct numeric<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
    }

    // difference between 1.0 and the next value representable as bf8
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
    }

    // maximum rounding error
    // bin : 7 65432 10
    // bits: s eeeee mm
    //       0 01110 00 (0.5)
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
    }

    CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
    {
        return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
    }
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif

// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x)
{
    return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };

CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
    return bit_cast<bf8_t>(static_cast<bf8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };

} // namespace ck_tile
include/ck_tile/core/numeric/half.hpp 0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>

#pragma once

namespace ck_tile {

using fp16_hip_t = _Float16; // most hip internal functions use this type
using fp16_raw_t = uint16_t;

CK_TILE_HOST_DEVICE constexpr float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE constexpr double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE constexpr fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE constexpr fp16_hip_t double_to_fp16_hip(const double& x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses fp16_hip_t as the interchangeable data type for float16
struct alignas(2) half_t
{
    using raw_type = fp16_raw_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr half_t bit_cast(raw_type x)
    {
        half_t y;
        y.data = x;
        return y;
    }

    CK_TILE_HOST_DEVICE
    constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // constructor
    constexpr half_t() : data{} {}

    // construct from HIP half
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const unsigned int& x)
        : half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const
    {
        return static_cast<int>(fp16_to_float_hip(to_fp16()));
    }

    CK_TILE_HOST_DEVICE
    explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

template <typename>
struct native_t;

template <>
struct native_t<half_t>
{
    using type = _Float16;
};

using fp16_t     = half_t;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t     = _Float16;
using half_t     = _Float16;
using fp16_raw_t = ushort;
#endif
// conversions
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
    // return __half2float(x);
    return static_cast<float>(x);
}

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
    return static_cast<double>(fp16_to_float_hip(x));
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
    return __float2half(x);
    // return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
    // return __float2half(x);
    return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double(const half_t& x) { return static_cast<double>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
// limits
template <class T>
struct numeric;

template <>
struct numeric<half_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr half_t min()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t lowest()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t max()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
    }

    // difference between 1.0 and the next value representable as half
    CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
    }

    // maximum rounding error
    // bin : f edcba 9876543210
    // bits: s eeeee mmmmmmmmmm
    //       0 01110 0000000000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr half_t round_error()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr half_t infinity()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
    }

    CK_TILE_HOST_DEVICE static constexpr half_t zero()
    {
        return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
    }
};
template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<half_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 10;
    static constexpr int bias = 15;

    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;

    using bitwise_type = uint16_t;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
{
    return __heq(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE bool operator!=(const half_t& x, const half_t& y)
{
    return __hne(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE bool operator<(const half_t& x, const half_t& y)
{
    return __hlt(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE bool operator<=(const half_t& x, const half_t& y)
{
    return __hle(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE bool operator>(const half_t& x, const half_t& y)
{
    return __hgt(x.to_fp16(), y.to_fp16());
}

CK_TILE_DEVICE bool operator>=(const half_t& x, const half_t& y)
{
    return __hge(x.to_fp16(), y.to_fp16());
}
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif
// math
CK_TILE_HOST_DEVICE half_t abs(const half_t& x)
{
    return bit_cast<half_t>(static_cast<fp16_raw_t>(x.get() & 0x7fff));
}
CK_TILE_HOST_DEVICE bool isnan(const half_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7C00;
}

CK_TILE_DEVICE half_t sqrt(half_t x)
{
    return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
}

CK_TILE_DEVICE half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); }

CK_TILE_DEVICE half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); }

CK_TILE_DEVICE half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); }
#endif

} // namespace ck_tile
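A minimal host-side sketch of how the conversion helpers and the numeric<half_t> limits above compose (the main() harness is illustrative and not part of the header):

#include "ck_tile/core/numeric/half.hpp"
#include <cassert>

int main()
{
    using namespace ck_tile;

    half_t h = float_to_fp16(1.5f); // 1.5 is exactly representable in fp16
    assert(fp16_to_float(h) == 1.5f);

    // max() above is the bit pattern 0x7BFF, which decodes to 65504
    assert(fp16_to_float(numeric<half_t>::max()) == 65504.0f);
    return 0;
}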
include/ck_tile/core/numeric/integer.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {

using index_t      = int32_t;
using long_index_t = int64_t;
using int8_t       = int8_t;

} // namespace ck_tile
include/ck_tile/core/numeric/integral_constant.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {

template <auto v>
struct constant
{
    using value_type = decltype(v);
    using type       = constant; // using injected-class-name

    static constexpr value_type value = v;

    CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
    CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
    CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};

template <typename T, T v>
struct integral_constant : constant<v>
{
    using value_type = T;
    using type       = integral_constant; // using injected-class-name

    static constexpr T value = v;
    // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
    // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; }
};

template <index_t v>
using number = constant<v>;

template <long_index_t v>
using long_number = constant<v>;

template <bool b>
using bool_constant = constant<b>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
{ \
return constant<(OP x)>{}; \
}
#define CK_TILE_BINARY_OP(OP) \
template <auto x, auto y> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
{ \
return constant<(x OP y)>{}; \
}
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)

CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)

#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP

} // namespace ck_tile
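A short sketch of what the operator overloads above buy you: arithmetic on number<> stays in the type domain, so the result remains usable wherever a compile-time constant is required (this snippet is illustrative, not part of the header):

#include "ck_tile/core/numeric/integral_constant.hpp"

using namespace ck_tile;

constexpr auto a = number<6>{};
constexpr auto b = number<4>{};
constexpr auto c = a * b + number<2>{}; // operator* and operator+ above yield constant<26>

static_assert(c.value == 26);
static_assert(c == 26); // implicit conversion to value_type
static_assert(decltype(c)::is_static());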
include/ck_tile/core/numeric/math.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace ck_tile {

template <typename Scale, Scale lhs>
struct scales_c
{
    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

template <typename Scale>
struct scales
{
    static_assert(std::is_copy_constructible_v<Scale>);

    CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}

    template <typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
        -> decltype(std::declval<const Scale&>() * rhs)
    {
        return lhs_ * rhs;
    }

    private:
    Scale lhs_;
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale) -> scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

template <>
struct plus<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs + rhs)
    {
        return lhs + rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus() -> plus<void, void>;

template <typename Left = void, typename Right = Left>
struct minus
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

template <>
struct minus<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs - rhs)
    {
        return lhs - rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus() -> minus<void, void>;

template <typename Left = void, typename Right = Left>
struct multiplies
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

template <>
struct multiplies<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs * rhs)
    {
        return lhs * rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies() -> multiplies<void, void>;
template <typename T>
struct maximize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

template <typename T>
struct minimize
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};

template <typename T>
struct integer_divide_ceiler
{
    CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
    {
        static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
        return (a + b - number<1>{}) / b;
    }
};
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}

template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - number<1>{}) / y;
}

template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}
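// Illustrative values (not part of the header): integer_divide_ceil(10, 4) == 3 and
// integer_least_multiple(10, 4) == 12, i.e. the smallest multiple of 4 covering 10.
// With number<> arguments the arithmetic stays at compile time, e.g.
// integer_divide_ceil(number<10>{}, number<4>{}) yields number<3>{} via the constant<>
// operator overloads.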
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
    return x;
}

template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
    return x > y ? x : y;
}

template <>
CK_TILE_DEVICE constexpr float max(float x, float y)
{
    return __builtin_fmaxf(x, y); // can result in v_max3_f32
}

template <>
CK_TILE_DEVICE constexpr double max(double x, double y)
{
    return __builtin_fmax(x, y); // maybe still v_max3_f32
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
    return X > y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
    return x > Y ? x : Y;
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return max(x, max(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
    return x;
}

template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
    return x < y ? x : y;
}

template <>
CK_TILE_DEVICE constexpr float min(float x, float y)
{
    return __builtin_fminf(x, y);
}

template <>
CK_TILE_DEVICE constexpr double min(double x, double y)
{
    return __builtin_fmin(x, y);
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
    return X < y ? X : y;
}

template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
    return x < Y ? x : Y;
}

template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    return min(x, min(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    return min(max(x, lowerbound), upperbound);
}

CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
    if(x < 0)
    {
        return gcd(-x, y);
    }
    else if(y < 0)
    {
        return gcd(x, -y);
    }
    else if(x == y || x == 0)
    {
        return y;
    }
    else if(y == 0)
    {
        return x;
    }
    else if(x > y)
    {
        return gcd(x % y, y);
    }
    else
    {
        return gcd(x, y % x);
    }
}

template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
    constexpr auto r = gcd(X, Y);

    return number<r>{};
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
{
    return gcd(x, gcd(ys...));
}

// least common multiple
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
    return (x * y) / gcd(x, y);
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
{
    return lcm(x, lcm(ys...));
}
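// Illustrative values (not part of the header): gcd(12, 18) == 6 and lcm(12, 18) == 36;
// the number<> overload keeps the result in the type domain, so
// gcd(number<12>{}, number<18>{}) is number<6>{}. Note lcm computes (x * y) / gcd(x, y),
// so very large inputs can overflow the intermediate product.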
template <typename Left = void, typename Right = Left>
struct equal
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

template <>
struct equal<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs == rhs)
    {
        return lhs == rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal() -> equal<void, void>;

template <>
struct equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        return bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
    }
};

template <>
struct equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        return bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
    }
};

template <typename Left = void, typename Right = Left>
struct less
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

template <>
struct less<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs < rhs)
    {
        return lhs < rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less() -> less<void, void>;

template <typename Left = void, typename Right = Left>
struct less_equal
{
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

template <>
struct less_equal<void, void>
{
    template <typename Left, typename Right>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
        -> decltype(lhs <= rhs)
    {
        return lhs <= rhs;
    }
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal() -> less_equal<void, void>;

template <>
struct less_equal<float, float>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
    {
        return lhs < rhs || bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
    }
};

template <>
struct less_equal<double, double>
{
    CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
    {
        return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
    }
};
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
    // TODO: x needs to be 2 ~ 0x7fffffff; 0, 1, or values larger than 0x7fffffff will fail
    return 1 << (32 - clz(x - 1));
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}

template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
    constexpr index_t y = next_power_of_two(X);
    return number<y>{};
}

CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
    // TODO: x needs to be 1 ~ 0x7fffffff
    // __builtin_clz will produce unexpected result if x is 0;
    return 31 - __builtin_clz(x);
}

CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
    // TODO: x needs to be 1 ~ 0x7fffffff
    return x == (1 << integer_log2_floor(x));
}

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

template <typename T>
struct log2e;

template <>
struct log2e<double>
{
    static constexpr double value = C_LOG2E;
};

template <>
struct log2e<float>
{
    static constexpr float value = C_LOG2E;
};

template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE float abs(const float& x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;
    y.u32 = y.u32 & 0x7fffffff;
    return y.f32;
}

CK_TILE_HOST_DEVICE bool isnan(const float& x)
{
    uint32_t xx = bit_cast<uint32_t>(x);
    return (xx & 0x7fffffff) > 0x7F800000;
}

CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }

CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }

CK_TILE_DEVICE float exp(float x) { return __expf(x); }
CK_TILE_HOST float exp(float x) { return std::expf(x); }

CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }
CK_TILE_HOST float exp2(float x) { return std::exp2f(x); }

CK_TILE_DEVICE float log(float x) { return __logf(x); }
CK_TILE_HOST float log(float x) { return std::logf(x); }

} // namespace ck_tile
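A small host-side sketch exercising the integer helpers above (the main() harness is illustrative and not part of the library):

#include "ck_tile/core/numeric/math.hpp"
#include <cassert>

int main()
{
    assert(ck_tile::next_power_of_two(17) == 32); // see the input-range caveat in its TODO
    assert(ck_tile::integer_log2_floor(32) == 5);
    assert(ck_tile::is_power_of_two_integer(32) && !ck_tile::is_power_of_two_integer(33));
    assert(ck_tile::clamp(7, 0, 5) == 5);
    return 0;
}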
include/ck_tile/core/numeric/numeric.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {

// this struct has the information of
// 1. limit of a certain type, similar to std::numeric_limits
// 2. some pre-defined values: zero, one...
//
template <typename T>
struct numeric
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and the next value representable by the type
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }

    CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }

    CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif

    CK_TILE_HOST_DEVICE static constexpr T log2e()
    {
        if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
        {
            return static_cast<T>(C_LOG2E);
        }
        else
        {
            return 0; // TODO: integer?
        }
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<float>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;

    static constexpr uint32_t Inf    = 0x7F800000;
    static constexpr uint32_t NegInf = 0xFF800000;
    static constexpr uint32_t NaN    = 0x7F800001;
    static constexpr uint32_t Neg0   = 0x80000000;

    using bitwise_type = uint32_t;
};

} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
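A sketch of how CK_TILE_ARITHMETIC_USING_FLOAT is meant to be used. demo_real below is a hypothetical wrapper (not part of the library) that stores the bits of a float and exposes the raw_type/data members the unary-minus branch of the macro expects:

#include "ck_tile/core/numeric/numeric.hpp"
#include <cassert>
#include <cstring>

// hypothetical wrapper type: stores the bits of a float
struct demo_real
{
    using raw_type = uint32_t;
    raw_type data  = 0;

    demo_real() = default;
    explicit demo_real(float f) { std::memcpy(&data, &f, sizeof(f)); }
    explicit operator float() const
    {
        float f;
        std::memcpy(&f, &data, sizeof(f));
        return f;
    }
};

// emits ==, !=, <, <=, >, >=, +, -, *, /, compound assignment and ++/-- for demo_real,
// each computed by round-tripping through float; unary minus flips the sign bit directly
CK_TILE_ARITHMETIC_USING_FLOAT(inline, demo_real)

int main()
{
    demo_real a(2.0f), b(0.5f);
    assert(static_cast<float>(a * b) == 1.0f);
    assert(static_cast<float>(-a) == -2.0f);
    assert(a > b);
    return 0;
}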
include/ck_tile/core/numeric/type_convert.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {

#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
    return static_cast<Y>(x);
}
#else
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    return static_cast<Y>(x);
}

// Convert X to Y, either X or Y is a const data type.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    using non_const_y = std::remove_const_t<Y>;
    using non_const_x = std::remove_const_t<X>;
    return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}

#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)                     \
    template <>                                                                  \
    CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x)  \
    {                                                                            \
        return sname_##_to_##dname_(x);                                          \
    }

CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)

CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

#undef CK_TILE_TYPE_CONVERT
#endif

} // namespace ck_tile
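A brief sketch of the dispatch above: generic scalar casts go through the two overloads, while the CK_TILE_TYPE_CONVERT table routes float <-> fp16/bf16/fp8/bf8 through the per-type helpers. Illustrative usage (not part of the header):

#include "ck_tile/core/numeric/type_convert.hpp"

void example()
{
    using namespace ck_tile;

    float f  = 0.1f;
    fp16_t h = type_convert<fp16_t>(f); // routed to float_to_fp16(f) by the table above
    float g  = type_convert<float>(h);  // routed to fp16_to_float(h)
    (void)g;                            // g is 0.1f rounded to the nearest fp16, not exactly 0.1f
}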
include/ck_tile/core/numeric/vector_type.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

// this structure is used to pick up the <base> type inside
// using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allows native types + bool in this term (custom types will fail).
// specialize this structure to supply the proper <base> type
template <typename T>
struct native_t
{
    using type = remove_cvref_t<T>;
};

// we name this ext_vector purposely, because the clang ext_vector_type extension only accepts
// literal basic types to construct an ext_vector_type. You must be very careful using this, or
// you will have a lot of compiler errors,
// e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will have compiler error
namespace impl {

template <typename T_, index_t N_>
struct ext_vector
{
    static constexpr index_t N = N_;
    using value_type           = typename native_t<remove_cvref_t<T_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
    static constexpr index_t N = Vs_ * N_;
    using value_type           = typename native_t<remove_cvref_t<V_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

} // namespace impl

template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;

// by default, any type will result in a vector_size=1 with scalar_type=T traits.
// ... unless we have other vector_traits specialization
template <typename T>
struct vector_traits
{
    using scalar_type                    = remove_cvref_t<T>;
    static constexpr index_t vector_size = 1;
};

// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
    using scalar_type                    = T;
    static constexpr index_t vector_size = N;
};

template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                          typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
// below are some pre-defines of ext_vector_type
// attention! 2 vector types could be just the same type

// fp64
using fp64_t   = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));

// fp32
using fp32_t    = float;
using fp32x2_t  = float __attribute__((ext_vector_type(2)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));
using fp32x8_t  = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));

// fp16
// using fp16_t = ...
using fp16x2_t  = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t  = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t  = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));

// bf16
// using bf16_t = ...
using bf16x2_t  = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t  = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t  = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));

// i32
// using int32_t = ...
using int32x2_t  = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t  = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t  = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

// i16
// using int16_t = ...
using int16x2_t  = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t  = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t  = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));

// u16
// using uint16_t
using uint16x2_t  = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t  = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t  = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));

// i8
// using int8_t
using int8x2_t  = int8_t __attribute((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t  = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t  = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t  = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t  = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t  = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t  = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t  = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t  = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t  = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t  = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif

} // namespace ck_tile
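A compile-time sketch of the traits machinery above (illustrative, not part of the header):

#include "ck_tile/core/numeric/vector_type.hpp"
#include <type_traits>

// ext_vector_t composes: a vector-of-vector request flattens into one wider vector
static_assert(std::is_same_v<ck_tile::ext_vector_t<float, 4>, ck_tile::fp32x4_t>);
static_assert(std::is_same_v<ck_tile::ext_vector_t<ck_tile::fp32x4_t, 2>, ck_tile::fp32x8_t>);

// vector_traits recovers the lane count and scalar type
static_assert(ck_tile::vector_traits<ck_tile::fp16x8_t>::vector_size == 8);
static_assert(std::is_same_v<ck_tile::vector_traits<ck_tile::fp16x8_t>::scalar_type, _Float16>);
static_assert(ck_tile::has_same_scalar_type<ck_tile::fp16x2_t, ck_tile::fp16x8_t>::value);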
include/ck_tile/core/tensor/buffer_view.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
//
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
//        buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
template <address_space_enum BufferAddressSpace,
          typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default>
struct buffer_view;

// Address Space: generic
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
//
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::generic,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::generic;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;

            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
            *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: generic, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
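// Illustrative usage sketch (not part of the header): a generic-address-space view over n
// floats that returns zero for masked-out elements could be built and used from device code
// roughly as
//
//   float* p = ...; // generic pointer
//   buffer_view<address_space_enum::generic, float, index_t, true> v{p, n};
//   fp32x4_t x = v.template get<fp32x4_t>(i, i + 3 < n); // vector load, i aligned to 4
//   v.template set<fp32x4_t>(i, i + 3 < n, x);           // vector store
//
// the bool argument is the caller-supplied validity predicate; with
// InvalidElementUseNumericalZeroValue == true an invalid get() yields X{0}.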
// Address Space: Global
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
//
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence>
struct buffer_view<address_space_enum::global,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   Coherence>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::global;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   Coherence,
                                                                   oob_conditional_check>(
                    p_data_, i, is_valid_element, buffer_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<
                    remove_cvref_t<T>,
                    t_per_x,
                    Coherence,
                    oob_conditional_check>(
                    p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{numeric<remove_cvref_t<T>>::zero()};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
    {
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
            dst, p_data_, i, buffer_size_, is_valid_element);
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
    {
        // X is vector of T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
            smem, p_data_, i, buffer_size_);
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
            this->template atomic_add<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
            this->template atomic_max<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template set<X>(i, is_valid_element, tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_USE_AMD_BUFFER_STORE
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
            x, p_data_, i, is_valid_element, buffer_size_);
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(get_address_space() == address_space_enum::global, "only support global mem");

#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(get_address_space() == address_space_enum::global, "only support global mem");

#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        bool constexpr use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: Global, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
// Address Space: LDS
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
//
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::lds,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::lds;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                       scalar_per_t_vector * scalar_per_x_vector>;
            // using buf_t = ushort __attribute__((ext_vector_type(8)));
            auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
            return bit_cast<X>(rtn);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
#else
        bool constexpr workaround_int8_ds_write_issue = false;
#endif

        if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                  int8_t>::value &&
                     workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
                static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
                               std::is_same<remove_cvref_t<X>, int8_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   std::is_same<remove_cvref_t<X>, int8x16_t>::value),
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                             std::is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T
>
,
int8x4_t
>::
value
&&
std
::
is_same
<
remove_cvref_t
<
X
>
,
int8x4_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32_t
*>
(
&
x
);
}
else
if
constexpr
(
std
::
is_same
<
remove_cvref_t
<
T
>
,
int8x8_t
>::
value
&&
std
::
is_same
<
remove_cvref_t
<
X
>
,
int8x8_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32x2_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32x2_t
*>
(
&
x
);
}
else
if
constexpr
(
std
::
is_same
<
remove_cvref_t
<
T
>
,
int8x16_t
>::
value
&&
std
::
is_same
<
remove_cvref_t
<
X
>
,
int8x16_t
>::
value
)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*
c_style_pointer_cast
<
int32x4_t
*>
(
&
p_data_
[
i
])
=
*
c_style_pointer_cast
<
const
int32x4_t
*>
(
&
x
);
}
}
}
else
{
if
(
is_valid_element
)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
=
x
;
__builtin_memcpy
(
&
(
p_data_
[
i
]),
&
tmp
,
sizeof
(
X
));
#else
using
buf_t
=
ext_vector_t
<
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
,
scalar_per_t_vector
*
scalar_per_x_vector
>
;
*
c_style_pointer_cast
<
buf_t
*>
(
&
p_data_
[
i
])
=
reinterpret_cast
<
const
buf_t
&>
(
x
);
#endif
}
}
}
// FIXME: remove
CK_TILE_DEVICE
static
constexpr
bool
is_static_buffer
()
{
return
false
;
}
// FIXME: remove
CK_TILE_DEVICE
static
constexpr
bool
is_dynamic_buffer
()
{
return
true
;
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"buffer_view{"
);
// AddressSpace
printf
(
"AddressSpace: Lds, "
);
// p_data_
printf
(
"p_data_: %p, "
,
static_cast
<
void
*>
(
const_cast
<
remove_cvref_t
<
T
>*>
(
p_data_
)));
// buffer_size_
printf
(
"buffer_size_: "
);
print
(
buffer_size_
);
printf
(
", "
);
// invalid_element_value_
printf
(
"invalid_element_value_: "
);
print
(
invalid_element_value_
);
printf
(
"}"
);
}
};
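// --- Illustrative usage only (not part of this header): a minimal sketch of driving
// the LDS specialization above. `smem` and `v` are hypothetical names, and `fp32x4_t`
// stands for whatever 4-wide vector alias with a matching scalar type is available
// in this codebase.
//
//   __shared__ float smem[256];
//   auto lds_buf = make_buffer_view<address_space_enum::lds>(smem, 256);
//
//   // vectorized read of 4 floats at T-offset 8 (the offset must be X-aligned);
//   // an invalid element reads back as X{0} when InvalidElementUseNumericalZeroValue holds
//   auto v = lds_buf.get<fp32x4_t>(8, /*is_valid_element=*/true);
//
//   // vectorized write of the same 4 floats at T-offset 16
//   lds_buf.set<fp32x4_t>(16, /*is_valid_element=*/true, v);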
// Address Space: Vgpr
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::vgpr,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::vgpr;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;

            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
            *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: Vgpr, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
template <address_space_enum BufferAddressSpace,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
          typename T,
          typename BufferSizeType>
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
{
    return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
}

template <address_space_enum BufferAddressSpace,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
          typename T,
          typename BufferSizeType,
          typename X,
          typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
                                  bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
{
    return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
        p, buffer_size, invalid_element_value};
}
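// --- Illustrative usage only (not part of this header): the two factory overloads
// above differ in out-of-range policy. The two-argument form makes invalid elements
// read back as numerical zero; the three-argument form substitutes a caller-provided
// sentinel. `p_global` and `n` are hypothetical names.
//
//   auto zero_fill = make_buffer_view<address_space_enum::global>(p_global, n);
//   auto sentinel  = make_buffer_view<address_space_enum::global>(p_global, n, -1.0f);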
} // namespace ck_tile
include/ck_tile/core/tensor/load_tile.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto
load_tile_raw(T& tile,
              const tile_window_with_static_distribution<BottomTensorView_,
                                                         WindowLengths_,
                                                         TileDistribution_,
                                                         NumCoord>& tile_window,
              bool_constant<oob_conditional_check> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window)
{
    return tile_window.async_load(lds_tile);
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
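// --- Illustrative usage only (not part of this header): async_load_tile_raw issues
// the global->LDS copy and returns without waiting; async_load_fence then blocks
// until at most `cnt` vector-memory operations remain outstanding (0 waits for all
// of them). `lds_window` and `gmem_window` are hypothetical tile windows.
//
//   async_load_tile_raw(lds_window, gmem_window); // fire-and-forget copy into LDS
//   async_load_fence();                           // s_waitcnt vmcnt(0): wait for completion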
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
    return null_tensor{};
}

template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}

} // namespace ck_tile
include/ck_tile/core/tensor/null_tensor.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

struct null_tensor
{
};

} // namespace ck_tile
include/ck_tile/core/tensor/null_tile_window.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"

namespace ck_tile {

// placeholder type if we want to opt-out a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;

    using BottomTensorIndex = array<index_t, WindowLengths::size()>;

    CK_TILE_DEVICE constexpr null_tile_window() = default;

    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    WindowLengths window_lengths_;
};

// utility to check if this is a Null Tile Window
namespace impl {

template <typename>
struct is_null_tile_window : public std::false_type
{
};

template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};

} // namespace impl

template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}

template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
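// --- Illustrative usage only (not part of this header): passing a null window where
// a kernel treats an operand as optional. The lengths must be compile-time constants
// to satisfy the static_assert above; the 64x64 extents here are arbitrary.
//
//   auto bias_window = make_null_tile_window(make_tuple(number<64>{}, number<64>{}));
//   auto bias        = load_tile(bias_window); // returns null_tensor{}; no memory access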
template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}

} // namespace ck_tile
include/ck_tile/core/tensor/shuffle_tile.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"

namespace ck_tile {

namespace detail {

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    using DataType = typename InTensor::DataType;

    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;

        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;

        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];

            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });

        return rh_major_minor_to_y_;
    };

    constexpr auto rh_major_minor_to_y_in  = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});

    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;

        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }

        return y_dim_out_to_in_;
    }();

    //
    constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();

    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = array<DataType, vec_length_in>;
    using OutVec = array<DataType, vec_length_out>;

    // using InVec = typename InVec::type;
    // using OutVec = typename OutVec::type;

    // SFC
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();

    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    thread_buffer<InVec, num_vec_in> in_vectors;
    thread_buffer<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);

            static_assert(in_offset % vec_length_in == 0);

            in_vectors(i).template get_as<InVec>()(I0) =
                in_tensor.get_thread_buffer()
                    .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);

            static_assert(out_offset % vec_length_out == 0);

            out_tensor.get_thread_buffer().template set_as<OutVec>(
                number<out_offset / vec_length_out>{},
                out_vectors[i].template get_as<OutVec>()[I0]);
        });
    });
}

} // namespace detail

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InDstrEncode  = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;

    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);

    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
    }
}
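// --- Illustrative usage only (not part of this header): shuffling one register tile
// into another distribution whose encoding matches in everything but Y-dimension
// order, which is the case shuffle_tile handles in-thread. `in_dstr` and `out_dstr`
// are hypothetical tile distributions.
//
//   auto t_in  = make_static_distributed_tensor<float>(in_dstr);
//   auto t_out = make_static_distributed_tensor<float>(out_dstr);
//   /* ... fill t_in ... */
//   shuffle_tile(t_out, t_in); // per-thread vector transpose via transpose_vectors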
} // namespace ck_tile
include/ck_tile/core/tensor/slice_tile.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
    // NOTE: This API will override the origin of the tile window!
    static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
    static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());

    constexpr auto slice_lengths = slice_ends - slice_begins;

    return make_tile_window(tile.get_bottom_tensor_view(),
                            sequence_to_tuple_of_number(slice_lengths),
                            to_multi_index(slice_begins));
}
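// --- Illustrative usage only (not part of this header): carving a sub-window out of
// a statically sized tile window. As the NOTE above says, the slice origin replaces
// the window's original origin. `full_window` is a hypothetical 2D tile window.
//
//   // rows [0, 32) and columns [64, 128) of the underlying tensor view
//   auto sub = get_slice_tile(full_window, sequence<0, 64>{}, sequence<32, 128>{});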
template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType     = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);

    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);

    return sliced_tensor;
}

template <typename DstDataType_,
          typename DstStaticTileDistribution_,
          typename SrcDataType_,
          typename SrcStaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
               const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");

    dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}

} // namespace ck_tile