gaoqiong / composable_kernel_ROCM · Commits

Commit 20ddaeba, authored Apr 22, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: c5f1cdf7, 43879b89
Changes: 236 files. Showing 20 changed files with 9083 additions and 0 deletions.
include/ck_tile/core/algorithm/coordinate_transform.hpp       +1672  -0
include/ck_tile/core/algorithm/space_filling_curve.hpp         +166  -0
include/ck_tile/core/arch/amd_buffer_addressing.hpp           +2031  -0
include/ck_tile/core/arch/arch.hpp                              +93  -0
include/ck_tile/core/arch/utility.hpp                           +62  -0
include/ck_tile/core/config.hpp                                +156  -0
include/ck_tile/core/container/array.hpp                       +251  -0
include/ck_tile/core/container/container_helper.hpp            +499  -0
include/ck_tile/core/container/map.hpp                         +164  -0
include/ck_tile/core/container/meta_data_buffer.hpp             +99  -0
include/ck_tile/core/container/multi_index.hpp                 +100  -0
include/ck_tile/core/container/sequence.hpp                   +1114  -0
include/ck_tile/core/container/span.hpp                         +78  -0
include/ck_tile/core/container/statically_indexed_array.hpp     +41  -0
include/ck_tile/core/container/thread_buffer.hpp               +165  -0
include/ck_tile/core/container/tuple.hpp                       +781  -0
include/ck_tile/core/numeric/bfloat16.hpp                      +342  -0
include/ck_tile/core/numeric/float8.hpp                        +871  -0
include/ck_tile/core/numeric/half.hpp                          +385  -0
include/ck_tile/core/numeric/integer.hpp                        +13  -0
include/ck_tile/core/algorithm/coordinate_transform.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"

namespace ck_tile {

enum struct coord_transform_enum
{
    undefined,
    pass_through,
    pad,
    embed,
    merge,
    unmerge,
    replicate,
    xor_t,
    offset,
};
template <index_t NDimLow, index_t NDimUp>
struct base_transform
{
    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::undefined;
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }

    // return safe value for vector length/stride, based on compile-time known only variables
    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths&,
                                                          const LowVectorStrides&)
    {
        if constexpr(NDimUp > 0)
        {
            array<index_t, NDimUp> up_vector_lengths{-1};
            array<index_t, NDimUp> up_vector_strides{-1};

            return make_tuple(up_vector_lengths, up_vector_strides);
        }
        else
        {
            return make_tuple(array<index_t, 0>{}, array<index_t, 0>{});
        }
    }
};
template <typename LowLength>
struct pass_through : public base_transform<1, 1>
{
    static constexpr auto type_enum = coord_transform_enum::pass_through;

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(LowLength{}));

    UpLengths up_lengths_;

    CK_TILE_HOST_DEVICE constexpr pass_through() = default;

    CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
        : up_lengths_{make_tuple(low_length)}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::pass_through;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
                                                                    const UpIdx& idx_up)
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}];
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        return make_tuple(low_vector_lengths, low_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("pass_through{");
        printf("up_lengths_:");
        print(up_lengths_);
        printf("}");
    }
};
template <typename LowLength,
          typename LeftPadLength,
          typename RightPadLength,
          bool SkipIsValidCheck = false>
struct pad : public base_transform<1, 1>
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;
    RightPadLength right_pad_length_;

    CK_TILE_HOST_DEVICE constexpr pad()
        : up_lengths_{}, left_pad_length_{}, right_pad_length_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
                                      const LeftPadLength& left_pad_length,
                                      const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
          left_pad_length_{left_pad_length},
          right_pad_length_{right_pad_length}
    {
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck ||
               ((idx_up[number<0>{}] >= left_pad_length_) &&
                (idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_));
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
               ck_tile::is_known_at_compile_time<RightPadLength>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("pad{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("left_pad_length_: ");
        print(left_pad_length_);
        printf(", ");
        printf("right_pad_length_: ");
        print(right_pad_length_);
        printf("}");
    }
};
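// [Editor's illustrative sketch, not part of the original header.] A minimal, hedged example of
// how the pad transform above maps indices; it assumes multi_index/number behave as used in this
// file and is kept inside #if 0 (a convention already used in this commit) so it never compiles.
#if 0
CK_TILE_HOST_DEVICE inline void pad_example_sketch()
{
    // lower length 10, left pad 2, right pad 3 -> upper length 15
    pad<index_t, index_t, index_t> t{10, 2, 3};

    multi_index<1> idx_low{};
    multi_index<1> idx_up{};
    idx_up(number<0>{}) = 5;

    t.calculate_lower_index(idx_low, idx_up); // idx_low[0] == 5 - 2 == 3

    // valid because 2 <= 5 < 15 - 3
    const bool ok = t.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
    (void)ok;
}
#endif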
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;

    CK_TILE_HOST_DEVICE constexpr left_pad() = default;

    CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
                                           const LeftPadLength& left_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
    {
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<LeftPadLength>::value;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        // TODO: we allow the vector length to pass through unchanged. If one needs a per-pixel
        // check, the guaranteed vector length should be changed while creating the tensor view.
        // It is up to the runtime to check that the padding length is a multiple of the vector
        // length.
        return make_tuple(low_vector_lengths, low_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("left_pad{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("left_pad_length_: ");
        print(left_pad_length_);
        printf("}");
    }
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LowLength low_length_;
    RightPadLength right_pad_length_;

    CK_TILE_HOST_DEVICE constexpr right_pad() = default;

    CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
                                            const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + right_pad_length)},
          low_length_{low_length},
          right_pad_length_{right_pad_length}
    {
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
                                                                    const UpIdx& idx_up)
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}];
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<LowLength>::value &&
               ck_tile::is_known_at_compile_time<RightPadLength>::value;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        // TODO: we allow the vector length to pass through unchanged. If one needs a per-pixel
        // check, the guaranteed vector length should be changed while creating the tensor view.
        // It is up to the runtime to check that the padding length is a multiple of the vector
        // length.
        return make_tuple(low_vector_lengths, low_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("right_pad{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("right_pad_length_: ");
        print(right_pad_length_);
        printf("}");
    }
};
// idx_low = sum over i of coefficients[i] * idx_up[i], for i in [0, NDimUp).
// UpLengths and Coefficients can be any of the following:
//   1) Tuple of index_t, which is known at run-time, or
//   2) Tuple of number, which is known at compile-time, or
//   3) Tuple of a mixture of index_t and number, which is known partially at run-time and
//      partially at compile-time.
// (An illustrative example follows this struct definition.)
template <typename UpLengths,
          typename Coefficients,
          typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
struct embed : public base_transform<1, UpLengths::size()>
{
    static constexpr index_t NDimUp = UpLengths::size();

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<NDimUp>;

    UpLengths up_lengths_;
    Coefficients coefficients_;

    CK_TILE_HOST_DEVICE constexpr embed() = default;

    CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
                                        const Coefficients& coefficients)
        : up_lengths_{up_lengths}, coefficients_{coefficients}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::embed;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
            idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
        });
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx&) const
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_diff_low(number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}(
            [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<Coefficients>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("embed{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("coefficients_: ");
        print(coefficients_);
        printf("}");
    }
};
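// [Editor's illustrative sketch, not part of the original header.] embed computes
// idx_low[0] = sum_i coefficients[i] * idx_up[i]; e.g. with coefficients (8, 1) an upper index
// (2, 5) lowers to 2*8 + 5*1 = 21. Hypothetical usage, guarded by #if 0 so it never compiles.
#if 0
CK_TILE_HOST_DEVICE inline void embed_example_sketch()
{
    embed<decltype(make_tuple(4, 8)), decltype(make_tuple(8, 1))> t{make_tuple(4, 8),
                                                                    make_tuple(8, 1)};

    multi_index<1> idx_low{};
    multi_index<2> idx_up{};
    idx_up(number<0>{}) = 2;
    idx_up(number<1>{}) = 5;

    t.calculate_lower_index(idx_low, idx_up); // idx_low[0] == 2 * 8 + 5 * 1 == 21
}
#endif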
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
    template <index_t I>
    CK_TILE_HOST_DEVICE constexpr auto operator()(number<I> i) const
    {
        return magic_division::calculate_magic_numbers(LowLengths{}[i]);
    }
};
// Implementation of the "merge" transformation primitive that uses magic-number division to lower
// both the multi-index and the delta of the multi-index.
// Caution:
//   1. The magic number division implementation being used produces a correct result only if the
//      dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
//   2. Magic number division for an int32_t dividend has not been implemented; an int32_t
//      dividend is bit-wise reinterpreted as uint32_t and the magic number division
//      implementation for uint32_t is then used.
//   3. For the merge primitive, the upper-index is the dividend.
//   4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
//   5. When the upper-index is int32_t (i.e. when index_t is int32_t), its value needs to be
//      non-negative.
// (An illustrative index-lowering sketch follows this struct definition.)
template <typename LowLengths>
struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
{
    static constexpr index_t NDimLow = LowLengths::size();

    using LowerIndex = multi_index<NDimLow>;
    using UpperIndex = multi_index<1>;

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));

    using LowLengthsMagicDivisor = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
        number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsMagicDivisor low_lengths_magic_divisor_;
    UpLengths up_lengths_;

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division() = default;

    CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_magic_divisor_{generate_tuple(
              [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
              number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
    {
        static_assert(LowerIndex::size() == NDimLow, "wrong!");
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::merge;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[I0];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                magic_division::do_magic_division(tmp,
                                                  this->low_lengths_magic_divisor_[i][I0],
                                                  this->low_lengths_magic_divisor_[i][I1]);
            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;
        });

        idx_low(number<0>{}) = tmp;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff&,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up_new) const
    {
        static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
                          LowIdx::size() == NDimLow && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[number<0>{}];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                magic_division::do_magic_division(tmp,
                                                  this->low_lengths_magic_divisor_[i][I0],
                                                  this->low_lengths_magic_divisor_[i][I1]);

            index_t idx_low_old = idx_low[i];

            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;

            idx_diff_low(i) = idx_low[i] - idx_low_old;
        });

        idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});

        idx_low(number<0>{}) = tmp;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<LowLengths>::value &&
               ck_tile::is_known_at_compile_time<LowLengthsMagicDivisor>::value &&
               ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        array<index_t, 1> up_vector_lengths{-1};
        array<index_t, 1> up_vector_strides{-1};

        up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
        up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];

        return make_tuple(up_vector_lengths, up_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("merge_v2_magic_division{");
        printf("low_lengths_ ");
        print(low_lengths_);
        printf(", ");
        printf("up_lengths_ ");
        print(up_lengths_);
        printf("}");
    }
};
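// [Editor's illustrative sketch, not part of the original header.] merge folds several lower
// dimensions into one upper dimension; lowering is repeated quotient/remainder, computed here
// with magic-number multiplication instead of an integer divide. For low_lengths = (4, 8), the
// upper index 21 lowers to (21 / 8, 21 % 8) = (2, 5). Hypothetical usage, guarded by #if 0 so it
// never compiles.
#if 0
CK_TILE_HOST_DEVICE inline void merge_example_sketch()
{
    merge_v2_magic_division<decltype(make_tuple(4, 8))> t{make_tuple(4, 8)};

    multi_index<2> idx_low{};
    multi_index<1> idx_up{};
    idx_up(number<0>{}) = 21;

    t.calculate_lower_index(idx_low, idx_up); // idx_low == (2, 5)
}
#endif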
// Implementation of the "merge" transformation primitive that uses division and mod. It is
// intended for low_lengths that are known at compile time and are powers of 2; otherwise
// performance will be very bad.
template <typename LowLengths>
struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
{
    static constexpr index_t NDimLow = LowLengths::size();

    using LowerIndex = multi_index<NDimLow>;
    using UpperIndex = multi_index<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod() = default;

    CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
    {
        static_assert(LowerIndex::size() == NDimLow, "wrong!");
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[number<0>{}];

        // division and mod
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp %= this->low_lengths_scan_[i];
        });

        idx_low(number<NDimLow - 1>{}) = tmp;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff&,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up_new) const
    {
        static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
                          LowIdx::size() == NDimLow && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0   = number<0>{};
        constexpr auto INm1 = number<NDimLow - 1>{};

        index_t tmp = idx_up_new[I0];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            const index_t tmp2 = idx_low[i];
            idx_low(i)         = tmp / this->low_lengths_scan_[i];
            idx_diff_low(i)    = idx_low[i] - tmp2;
            tmp %= this->low_lengths_scan_[i];
        });

        const index_t tmp2 = idx_low[INm1];
        idx_low(INm1)      = tmp;
        idx_diff_low(INm1) = idx_low[INm1] - tmp2;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<LowLengths>::value &&
               ck_tile::is_known_at_compile_time<LowLengthsScan>::value &&
               ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        array<index_t, 1> up_vector_lengths{-1};
        array<index_t, 1> up_vector_strides{-1};

        up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
        up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];

        return make_tuple(up_vector_lengths, up_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("Merge_v3_direct_division_mod{");
        printf("low_lengths_ ");
        print(low_lengths_);
        printf(", ");
        printf("low_lengths_scan_ ");
        print(low_lengths_scan_);
        printf(", ");
        printf("up_lengths_ ");
        print(up_lengths_);
        printf("}");
    }
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
    static constexpr index_t NDimUp = UpLengths::size();

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<NDimUp>;

    using UpLengthsScan =
        decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));

    UpLengths up_lengths_;
    UpLengthsScan up_lengths_scan_;

    CK_TILE_HOST_DEVICE constexpr unmerge() = default;

    CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
        : up_lengths_{up_lengths},
          up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::unmerge;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        if constexpr(!Use24BitIntegerCalculation)
        {
            idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}(
                [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
        }
        else
        {
            idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}([&](auto i) {
                idx_low(number<0>{}) =
                    (0x00ffffff & idx_low[number<0>{}]) +
                    (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
            });
        }
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx&) const
    {
        calculate_lower_index(idx_diff_low, idx_diff_up);

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<UpLengthsScan>::value;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE static constexpr auto
    calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides)
    {
        array<index_t, NDimUp> up_vector_lengths{-1};
        array<index_t, NDimUp> up_vector_strides{-1};

        constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];

        if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
        {
            if(low_vector_lengths[0] != -1)
            {
                up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
            }
        }

        up_vector_strides(NDimUp - 1) = low_vector_strides[0];

        return make_tuple(up_vector_lengths, up_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("unmerge{");
        printf("up_lengths_");
        print(up_lengths_);
        printf(", ");
        printf("up_lengths_scan_");
        print(up_lengths_scan_);
        printf("}");
    }
};
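// [Editor's illustrative sketch, not part of the original header.] unmerge is the inverse
// direction of merge: one lower dimension is split into several upper dimensions, and lowering
// linearizes the upper index using the reverse-exclusive-scan strides. With up_lengths = (4, 8)
// the scan is (8, 1), so (2, 5) lowers to 2*8 + 5 = 21. Hypothetical usage, guarded by #if 0 so
// it never compiles.
#if 0
CK_TILE_HOST_DEVICE inline void unmerge_example_sketch()
{
    unmerge<decltype(make_tuple(4, 8)), false> t{make_tuple(4, 8)};

    multi_index<1> idx_low{};
    multi_index<2> idx_up{};
    idx_up(number<0>{}) = 2;
    idx_up(number<1>{}) = 5;

    t.calculate_lower_index(idx_low, idx_up); // idx_low[0] == 2 * 8 + 5 == 21
}
#endif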
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
    LowerIndex low_idx_;

    CK_TILE_HOST_DEVICE constexpr freeze() = default;

    CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& /* idx_up */) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = low_idx_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& /* idx_diff_up */,
                                                       LowIdx& /* idx_low */,
                                                       const UpIdx& /* idx_up_new */)
    {
        idx_diff_low(number<0>{}) = 0;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<LowerIndex>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("freeze{");
        printf("low_idx_: ");
        print(low_idx_);
        printf("}");
    }
};
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
{
    using UpLengths = decltype(make_tuple(UpperLength{}));

    UpLengths up_lengths_;

    CK_TILE_HOST_DEVICE constexpr insert() = default;

    CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
        : up_lengths_{make_tuple(up_length)}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; }

    CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
    {
        static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void
    update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpperLength>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("insert{");
        print(up_lengths_);
        printf("}");
    }
};
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
{
    static constexpr index_t NDimUp = UpLengths::size();

    CK_TILE_HOST_DEVICE constexpr replicate() = default;

    CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths)
        : up_lengths_{up_lengths}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
    {
        static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void
    update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 0 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("replicate{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf("}");
    }

    UpLengths up_lengths_;
};
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));

    UpLengths up_lengths_;
    SliceBegin slice_begin_;
    SliceEnd slice_end_;

    CK_TILE_HOST_DEVICE constexpr slice() = default;

    CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
                                        const SliceBegin& slice_begin,
                                        const SliceEnd& slice_end)
        : up_lengths_{make_tuple(slice_end - slice_begin)},
          slice_begin_{slice_begin},
          slice_end_{slice_end}
    {
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<SliceBegin>::value &&
               ck_tile::is_known_at_compile_time<SliceEnd>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("slice{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("slice_begin_: ");
        print(slice_begin_);
        printf(", ");
        printf("slice_end_: ");
        print(slice_end_);
        printf("}");
    }
};
/*
 * \brief lower_idx = upper_idx % modulus.
 * TODO: Need an improved implementation since the modulo operation is expensive.
 */
template <typename Modulus, typename UpLength>
struct modulo : public base_transform<1, 1>
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;
    using UpLengths  = decltype(make_tuple(UpLength{}));

    Modulus modulus_;
    UpLengths up_lengths_;

    CK_TILE_HOST_DEVICE constexpr modulo() = default;

    CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
        : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
    {
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& up_idx) const
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        const auto idx_low_old = idx_low;

        idx_low[I0]      = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
        idx_diff_low[I0] = idx_low - idx_low_old;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("Modulus{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf("}");
    }
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
struct xor_t : public base_transform<2, 2>
{
    static constexpr auto type_enum = coord_transform_enum::xor_t;

    using LowerIndex = multi_index<2>;
    using UpperIndex = multi_index<2>;

    using UpLengths = LowLengths;

    UpLengths up_lengths_;
    RightShift right_shift_;

    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}

    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
                                        const RightShift& right_shift)
        : up_lengths_{low_lengths}, right_shift_{right_shift}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::xor_t;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}];

        const auto idx_low_1_tmp =
            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];

        const auto idx_low_1 =
            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;

        idx_low(number<1>{}) = idx_low_1;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff&,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up) const
    {
        static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
                          UpIdx::size() == 2,
                      "wrong! inconsistent # of dimension");

        const auto idx_low_old = idx_low;

        calculate_lower_index(idx_low, idx_up);

        idx_diff_low = idx_low - idx_low_old;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<RightShift>::value;
    }

    // MUST be static function
    template <typename LowVectorLengths, typename LowVectorStrides>
    CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(
        const LowVectorLengths& low_vector_lengths,
        const LowVectorStrides& low_vector_strides) const
    {
        array<index_t, 2> up_vector_lengths = low_vector_lengths;
        array<index_t, 2> up_vector_strides = low_vector_strides;

        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
        {
            if(low_vector_lengths[1] != -1)
            {
                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
            }
        }

        return make_tuple(up_vector_lengths, up_vector_strides);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("xor_t{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("right_shift_: ");
        print(right_shift_);
        printf("}");
    }
};
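// [Editor's illustrative sketch, not part of the original header.] xor_t swizzles the second of
// two dimensions based on the first: idx_low[1] = (idx_up[1] - idx_up[0] * right_shift) mod len1,
// wrapped to stay non-negative, while idx_low[0] passes through. For lengths (4, 8) and
// right_shift = 1, the upper index (3, 1) lowers to (3, (1 - 3) mod 8) = (3, 6). Hypothetical
// usage, guarded by #if 0 so it never compiles.
#if 0
CK_TILE_HOST_DEVICE inline void xor_example_sketch()
{
    xor_t<decltype(make_tuple(4, 8)), index_t> t{make_tuple(4, 8), 1};

    multi_index<2> idx_low{};
    multi_index<2> idx_up{};
    idx_up(number<0>{}) = 3;
    idx_up(number<1>{}) = 1;

    t.calculate_lower_index(idx_low, idx_up); // idx_low == (3, 6)
}
#endif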
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(LowLength{}));

    UpLengths up_lengths_;
    OffsetLength offset_length_;

    CK_TILE_HOST_DEVICE constexpr offset() = default;

    CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
                                         const OffsetLength& offset_length)
        : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::offset;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
                                                       const UpIdxDiff& idx_diff_up,
                                                       LowIdx& idx_low,
                                                       const UpIdx&)
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = number<0>{};

        idx_diff_low[I0] = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               ck_tile::is_known_at_compile_time<OffsetLength>::value;
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("offset{");
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");
        printf("offset_length_: ");
        print(offset_length_);
        printf("}");
    }
};
//*******************************************************************************************************

template <typename LowLength>
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
{
    return pass_through<LowLength>{low_length};
}

template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_pad_transform(const LowLength& low_length,
                   const LeftPad& left_pad,
                   const RightPad& right_pad,
                   bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
    return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}

template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_left_pad_transform(const LowLength& low_length,
                        const LeftPadLength& left_pad_,
                        bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
    return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
}

template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_right_pad_transform(const LowLength& low_length,
                         const RightPadLength& right_pad_,
                         bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
    return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
}

template <typename UpLengths,
          typename Coefficients,
          typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
                                                        const Coefficients& coefficients)
{
    return embed<UpLengths, Coefficients>{up_lengths, coefficients};
}

template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
    return merge_v2_magic_division<LowLengths>{low_lengths};
}

template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{
    return merge_v3_division_mod<LowLengths>{low_lengths};
}

template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
    return make_merge_transform_v2_magic_division(low_lengths);
}

template <typename UpLengths, bool Use24BitIntegerCalculation = false>
CK_TILE_HOST_DEVICE constexpr auto
make_unmerge_transform(const UpLengths& up_lengths,
                       bool_constant<Use24BitIntegerCalculation> = bool_constant<false>{})
{
    return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

template <typename LowerIndex>
CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
    return freeze<LowerIndex>{low_idx};
}

template <typename UpperIndex>
CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
    return insert<UpperIndex>{up_idx};
}

template <typename UpLengths>
CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{
    return replicate<UpLengths>{up_lengths};
}

template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
                                                        const SliceBegin& slice_begin,
                                                        const SliceEnd& slice_end)
{
    return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}

template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
                                                         const UpLength& up_length)
{
    return modulo<Modulus, UpLength>{modulus, up_length};
}

template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
                                                      const RightShift& right_shift)
{
    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
}

template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
                                                         const OffsetLength& offset_length)
{
    return offset<LowLength, OffsetLength>{low_length, offset_length};
}

} // namespace ck_tile
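Editor's note: the maker helpers at the end of the header are typically combined to describe a
tensor-view re-indexing. The snippet below is a hedged, minimal sketch of such a composition
(hypothetical usage written for this writeup, not code from this commit); it only relies on
make_right_pad_transform, make_merge_transform, make_tuple, multi_index and number as defined
above, and the expected result values are worked out by hand.

#include "ck_tile/core/algorithm/coordinate_transform.hpp"

// Hedged sketch: pad one dimension, then lower a merged index back to its components.
inline void coordinate_transform_sketch()
{
    using namespace ck_tile;

    auto pad_t   = make_right_pad_transform(10, 6);         // lower length 10 -> upper length 16
    auto merge_t = make_merge_transform(make_tuple(4, 16)); // (4, 16) -> single length-64 dim
    (void)pad_t;                                            // constructed only for illustration

    multi_index<2> low{};
    multi_index<1> up{};
    up(number<0>{}) = 37;

    merge_t.calculate_lower_index(low, up); // low == (37 / 16, 37 % 16) == (2, 5)
}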
include/ck_tile/core/algorithm/space_filling_curve.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
TensorLengths
,
typename
DimAccessOrder
,
typename
ScalarsPerAccess
,
bool
SnakeCurved
=
true
>
// # of scalars per access in each dimension
struct
space_filling_curve
{
static
constexpr
index_t
TensorSize
=
reduce_on_sequence
(
TensorLengths
{},
multiplies
{},
number
<
1
>
{});
static_assert
(
0
<
TensorSize
,
"space_filling_curve should be used to access a non-empty tensor"
);
static
constexpr
index_t
nDim
=
TensorLengths
::
size
();
using
Index
=
multi_index
<
nDim
>
;
static
constexpr
index_t
ScalarPerVector
=
reduce_on_sequence
(
ScalarsPerAccess
{},
multiplies
{},
number
<
1
>
{});
static
constexpr
auto
access_lengths
=
TensorLengths
{}
/
ScalarsPerAccess
{};
static
constexpr
auto
dim_access_order
=
DimAccessOrder
{};
static
constexpr
auto
ordered_access_lengths
=
container_reorder_given_new2old
(
access_lengths
,
dim_access_order
);
static
constexpr
auto
to_index_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
ordered_access_lengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
nDim
,
1
>::
type
{}),
make_tuple
(
sequence
<
0
>
{}));
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_access
()
{
static_assert
(
TensorLengths
::
size
()
==
ScalarsPerAccess
::
size
());
static_assert
(
TensorLengths
{}
%
ScalarsPerAccess
{}
==
typename
uniform_sequence_gen
<
TensorLengths
::
size
(),
0
>::
type
{});
return
reduce_on_sequence
(
TensorLengths
{},
multiplies
{},
number
<
1
>
{})
/
ScalarPerVector
;
}
template
<
index_t
AccessIdx1dHead
,
index_t
AccessIdx1dTail
>
static
CK_TILE_HOST_DEVICE
constexpr
auto
get_step_between
(
number
<
AccessIdx1dHead
>
,
number
<
AccessIdx1dTail
>
)
{
static_assert
(
AccessIdx1dHead
>=
0
&&
AccessIdx1dHead
<
get_num_of_access
(),
"1D index out of range"
);
static_assert
(
AccessIdx1dTail
>=
0
&&
AccessIdx1dTail
<
get_num_of_access
(),
"1D index out of range"
);
constexpr
auto
idx_head
=
get_index
(
number
<
AccessIdx1dHead
>
{});
constexpr
auto
idx_tail
=
get_index
(
number
<
AccessIdx1dTail
>
{});
return
idx_tail
-
idx_head
;
}
template
<
index_t
AccessIdx1d
>
static
CK_TILE_HOST_DEVICE
constexpr
auto
get_forward_step
(
number
<
AccessIdx1d
>
)
{
static_assert
(
AccessIdx1d
<
get_num_of_access
(),
"1D index should be larger than 0"
);
return
get_step_between
(
number
<
AccessIdx1d
>
{},
number
<
AccessIdx1d
+
1
>
{});
}
template
<
index_t
AccessIdx1d
>
static
CK_TILE_HOST_DEVICE
constexpr
auto
get_backward_step
(
number
<
AccessIdx1d
>
)
{
static_assert
(
AccessIdx1d
>
0
,
"1D index should be larger than 0"
);
return
get_step_between
(
number
<
AccessIdx1d
>
{},
number
<
AccessIdx1d
-
1
>
{});
}
template
<
index_t
AccessIdx1d
>
static
CK_TILE_HOST_DEVICE
constexpr
Index
get_index
(
number
<
AccessIdx1d
>
)
{
#if 0
/*
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
*/
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
constexpr
auto
access_strides
=
container_reverse_exclusive_scan
(
ordered_access_lengths
,
multiplies
{},
number
<
1
>
{});
constexpr
auto
idx_1d
=
number
<
AccessIdx1d
>
{};
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
// idim-th element of multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr
auto
compute_index
=
[
idx_1d
,
access_strides
](
auto
idim
)
constexpr
{
constexpr
auto
compute_index_impl
=
[
idx_1d
,
access_strides
](
auto
jdim
)
constexpr
{
auto
res
=
idx_1d
.
value
;
auto
id
=
0
;
static_for
<
0
,
jdim
.
value
+
1
,
1
>
{}([
&
](
auto
kdim
)
{
id
=
res
/
access_strides
[
kdim
].
value
;
res
-=
id
*
access_strides
[
kdim
].
value
;
});
return
id
;
};
constexpr
auto
id
=
compute_index_impl
(
idim
);
return
number
<
id
>
{};
};
constexpr
auto
ordered_access_idx
=
generate_tuple
(
compute_index
,
number
<
nDim
>
{});
#endif
constexpr
auto
forward_sweep
=
[
&
]()
{
statically_indexed_array
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
idim
)
{
index_t
tmp
=
ordered_access_idx
[
I0
];
static_for
<
1
,
idim
,
1
>
{}(
[
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_access_lengths
[
j
]
+
ordered_access_idx
[
j
];
});
forward_sweep_
(
idim
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate multi-dim tensor index
auto
idx_md
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
idim
)
{
ordered_idx
(
idim
)
=
!
SnakeCurved
||
forward_sweep
[
idim
]
?
ordered_access_idx
[
idim
]
:
ordered_access_lengths
[
idim
]
-
1
-
ordered_access_idx
[
idim
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
dim_access_order
)
*
ScalarsPerAccess
{};
}();
return
idx_md
;
}
    // FIXME: rename this function
    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});

        return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
    }
};

} // namespace ck_tile
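// Usage sketch (added for illustration, not part of the original header): the helpers above are
// typically driven from a compile-time loop over all accesses of the curve. The exact
// template-parameter list of the enclosing space_filling_curve struct (tensor lengths, dim access
// order, scalars per access) is assumed from context rather than quoted from this diff, so the
// alias below is hypothetical:
//
//   using sfc = ck_tile::space_filling_curve<ck_tile::sequence<4, 8>,   // tensor lengths (assumed)
//                                            ck_tile::sequence<0, 1>,   // dim access order (assumed)
//                                            ck_tile::sequence<1, 1>>;  // scalars per access (assumed)
//
//   ck_tile::static_for<0, sfc::get_num_of_access() - 1, 1>{}([&](auto i_access) {
//       constexpr auto idx  = sfc::get_index(i_access);        // multi-dim index of this access
//       constexpr auto step = sfc::get_forward_step(i_access);  // delta to the next access
//       /* issue the access at idx, then advance the tensor coordinate by step */
//   });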
include/ck_tile/core/arch/amd_buffer_addressing.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
struct __attribute__((packed)) buffer_resource
{
    const void* ptr;
    uint32_t range;
    uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
{
    buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
    return __builtin_bit_cast(int32x4_t, res);
}
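// Usage sketch (added for illustration, not part of the original header): a wave builds the
// 128-bit descriptor once per tensor and reuses it for all of its buffer instructions. The
// pointer and element count below are hypothetical kernel arguments.
CK_TILE_DEVICE int32x4_t example_make_srd_for_float_tensor(const float* p_global,
                                                           index_t num_elements)
{
    // base pointer, byte range, and the CK_TILE_BUFFER_RESOURCE_3RD_DWORD config word are packed
    // into four SGPR-sized ints by make_wave_buffer_resource
    return make_wave_buffer_resource(p_global, num_elements * sizeof(float));
}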
// TODO: glc/slc/...
template <index_t bytes>
struct buffer_load;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
// (exp_vector_type(xxx))
template <>
struct buffer_load<16>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t s_offset,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = fp32x4_t;
        asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
                     : "+v"(reinterpret_cast<mbuf_t&>(value))
                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
                     : "memory");
    }
};
template
<
>
struct
buffer_load
<
8
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
0
)
{
static_assert
(
sizeof
(
T
)
==
8
);
using
mbuf_t
=
fp32x2_t
;
asm
volatile
(
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load
<
4
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
using
mbuf_t
=
float
;
asm
volatile
(
"buffer_load_dword %0, %1, %2, %3 offen offset:%4"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load
<
2
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
// subdword is buggy, use dword buf and convert manually
using
mbuf_t
=
float
;
asm
volatile
(
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load
<
1
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
using
mbuf_t
=
float
;
asm
volatile
(
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
index_t
bytes
>
struct
buffer_load_if
;
template
<
>
struct
buffer_load_if
<
16
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
0
)
{
static_assert
(
sizeof
(
T
)
==
16
);
auto
saved_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
fp32x4_t
;
static_assert
(
sizeof
(
mbuf_t
)
==
sizeof
(
T
));
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
saved_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load_if
<
8
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
0
)
{
static_assert
(
sizeof
(
T
)
==
8
);
auto
saved_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
fp32x2_t
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
saved_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load_if
<
4
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
auto
saved_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
float
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
saved_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load_if
<
2
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
auto
saved_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
float
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
saved_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_load_if
<
1
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
0
)
{
static_assert
(
sizeof
(
T
)
==
4
);
auto
saved_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
float
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
"+v"
(
reinterpret_cast
<
mbuf_t
&>
(
value
))
:
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
saved_exec
)
:
"memory"
);
}
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
template
<
index_t
bytes
>
struct
buffer_store
;
template
<
>
struct
buffer_store
<
16
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
1
)
{
static_assert
(
sizeof
(
T
)
==
16
);
using
mbuf_t
=
fp32x4_t
;
asm
volatile
(
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store
<
8
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
1
)
{
static_assert
(
sizeof
(
T
)
==
8
);
using
mbuf_t
=
fp32x2_t
;
asm
volatile
(
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store
<
4
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
1
)
{
static_assert
(
sizeof
(
T
)
==
4
);
using
mbuf_t
=
float
;
asm
volatile
(
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store
<
2
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
1
)
{
static_assert
(
sizeof
(
T
)
==
2
);
using
mbuf_t
=
short
;
asm
volatile
(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store
<
1
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
/*flag*/
=
1
)
{
static_assert
(
sizeof
(
T
)
==
4
);
using
mbuf_t
=
float
;
asm
volatile
(
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
)
:
"memory"
);
}
};
template
<
index_t
bytes
>
struct
buffer_store_if
;
template
<
>
struct
buffer_store_if
<
16
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
1
)
{
static_assert
(
sizeof
(
T
)
==
16
);
auto
save_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
fp32x4_t
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
save_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store_if
<
8
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
1
)
{
static_assert
(
sizeof
(
T
)
==
8
);
auto
save_exec
=
__builtin_amdgcn_read_exec
();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using
mbuf_t
=
ext_vector_t
<
typename
T
::
value_type
,
T
::
size
()
>
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
save_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store_if
<
4
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
1
)
{
static_assert
(
sizeof
(
T
)
==
4
);
auto
save_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
float
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
save_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store_if
<
2
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
1
)
{
static_assert
(
sizeof
(
T
)
==
2
);
auto
save_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
short
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_store_short %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
save_exec
)
:
"memory"
);
}
};
template
<
>
struct
buffer_store_if
<
1
>
{
template
<
typename
T
>
CK_TILE_DEVICE
void
operator
()(
const
T
&
value
,
int32x4_t
res
/*buffer resource*/
,
index_t
v_offset
,
index_t
s_offset
,
index_t
i_offset
/*max 0xFFF*/
,
index_t
flag
=
1
)
{
static_assert
(
sizeof
(
T
)
==
4
);
auto
save_exec
=
__builtin_amdgcn_read_exec
();
using
mbuf_t
=
float
;
asm
volatile
(
"v_cmpx_le_u32 exec, 1, %5
\n
"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4
\n
"
"s_mov_b64 exec %6"
:
:
"v"
(
bit_cast
<
mbuf_t
>
(
value
)),
"v"
(
v_offset
),
"s"
(
res
),
"s"
(
s_offset
),
"n"
(
i_offset
),
"v"
(
flag
),
"s"
(
save_exec
)
:
"memory"
);
}
};
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
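// Usage sketch (added for illustration, not part of the original header): s_waitcnt vmcnt(n)
// blocks until at most n vector-memory operations issued by this wave are still outstanding, so
// a count of 0 makes every previously issued buffer_load result visible in its destination
// registers, while a nonzero count leaves later loads in flight for software pipelining.
CK_TILE_DEVICE void example_wait_for_all_buffer_loads()
{
    buffer_load_fence(0); // nothing left in flight after this point
}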
// clang-format off
namespace
impl
{
// can't use "+v" since there could be potential extra move(read/write)
// use "v" can help remove such duplicated moves
// besides, fake this as "memory" operation to force later valu after this fence
// TODO: may have scratch (because this is memory?)
// need to reduce extra move inside compiler
template
<
index_t
N
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
(
array
<
float
,
N
>&
b
)
{
static_for
<
0
,
b
.
size
(),
1
>
{}([
&
](
auto
i
){
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
i
))
:
"memory"
);
});
}
#if 1
// below specialization just merge size() of dwords into single section
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
2
>
(
array
<
float
,
2
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{}))
:
"memory"
);
}
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
3
>
(
array
<
float
,
3
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{})),
"v"
(
b
.
get
(
number
<
2
>
{}))
:
"memory"
);
}
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
4
>
(
array
<
float
,
4
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{})),
"v"
(
b
.
get
(
number
<
2
>
{})),
"v"
(
b
.
get
(
number
<
3
>
{}))
:
"memory"
);
}
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
8
>
(
array
<
float
,
8
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{})),
"v"
(
b
.
get
(
number
<
2
>
{})),
"v"
(
b
.
get
(
number
<
3
>
{})),
"v"
(
b
.
get
(
number
<
4
>
{})),
"v"
(
b
.
get
(
number
<
5
>
{})),
"v"
(
b
.
get
(
number
<
6
>
{})),
"v"
(
b
.
get
(
number
<
7
>
{}))
:
"memory"
);
}
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
16
>
(
array
<
float
,
16
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{})),
"v"
(
b
.
get
(
number
<
2
>
{})),
"v"
(
b
.
get
(
number
<
3
>
{})),
"v"
(
b
.
get
(
number
<
4
>
{})),
"v"
(
b
.
get
(
number
<
5
>
{})),
"v"
(
b
.
get
(
number
<
6
>
{})),
"v"
(
b
.
get
(
number
<
7
>
{})),
"v"
(
b
.
get
(
number
<
8
>
{})),
"v"
(
b
.
get
(
number
<
9
>
{})),
"v"
(
b
.
get
(
number
<
10
>
{})),
"v"
(
b
.
get
(
number
<
11
>
{})),
"v"
(
b
.
get
(
number
<
12
>
{})),
"v"
(
b
.
get
(
number
<
13
>
{})),
"v"
(
b
.
get
(
number
<
14
>
{})),
"v"
(
b
.
get
(
number
<
15
>
{}))
:
"memory"
);
}
template
<
>
CK_TILE_DEVICE
void
insert_dummy_dep_per_dword
<
32
>
(
array
<
float
,
32
>&
b
)
{
asm
volatile
(
" "
:
:
"v"
(
b
.
get
(
number
<
0
>
{})),
"v"
(
b
.
get
(
number
<
1
>
{})),
"v"
(
b
.
get
(
number
<
2
>
{})),
"v"
(
b
.
get
(
number
<
3
>
{})),
"v"
(
b
.
get
(
number
<
4
>
{})),
"v"
(
b
.
get
(
number
<
5
>
{})),
"v"
(
b
.
get
(
number
<
6
>
{})),
"v"
(
b
.
get
(
number
<
7
>
{})),
"v"
(
b
.
get
(
number
<
8
>
{})),
"v"
(
b
.
get
(
number
<
9
>
{})),
"v"
(
b
.
get
(
number
<
10
>
{})),
"v"
(
b
.
get
(
number
<
11
>
{})),
"v"
(
b
.
get
(
number
<
12
>
{})),
"v"
(
b
.
get
(
number
<
13
>
{})),
"v"
(
b
.
get
(
number
<
14
>
{})),
"v"
(
b
.
get
(
number
<
15
>
{})),
"v"
(
b
.
get
(
number
<
16
>
{})),
"v"
(
b
.
get
(
number
<
17
>
{})),
"v"
(
b
.
get
(
number
<
18
>
{})),
"v"
(
b
.
get
(
number
<
19
>
{})),
"v"
(
b
.
get
(
number
<
20
>
{})),
"v"
(
b
.
get
(
number
<
21
>
{})),
"v"
(
b
.
get
(
number
<
22
>
{})),
"v"
(
b
.
get
(
number
<
23
>
{})),
"v"
(
b
.
get
(
number
<
24
>
{})),
"v"
(
b
.
get
(
number
<
25
>
{})),
"v"
(
b
.
get
(
number
<
26
>
{})),
"v"
(
b
.
get
(
number
<
27
>
{})),
"v"
(
b
.
get
(
number
<
28
>
{})),
"v"
(
b
.
get
(
number
<
29
>
{})),
"v"
(
b
.
get
(
number
<
30
>
{})),
"v"
(
b
.
get
(
number
<
31
>
{}))
:
"memory"
);
}
#endif
CK_TILE_DEVICE
void
insert_dummy_dep
()
{}
template
<
typename
T
>
CK_TILE_DEVICE
void
insert_dummy_dep
(
T
&
buffer
)
{
// TODO: indeed we expect T to be multiple of dword. subdword is always buggy
using
da_type
=
array
<
float
,
(
sizeof
(
T
)
+
3
)
/
4
>
;
auto
&
dummy
=
reinterpret_cast
<
da_type
&>
(
buffer
);
insert_dummy_dep_per_dword
(
dummy
);
}
template
<
typename
Tx
,
typename
...
Ty
>
CK_TILE_DEVICE
void
insert_dummy_dep
(
Tx
&
bx
,
Ty
&
...
by
)
{
insert_dummy_dep
(
bx
);
insert_dummy_dep
(
by
...);
}
}
// clang-format on
template
<
typename
...
T
>
CK_TILE_DEVICE
void
buffer_load_fence
(
index_t
cnt
=
0
,
T
&
...
o
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
impl
::
insert_dummy_dep
(
o
...);
}
CK_TILE_DEVICE
void
buffer_store_fence
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
// buffer load i8
CK_TILE_DEVICE_EXTERN
int8_t
llvm_amdgcn_raw_buffer_load_i8
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.i8"
);
CK_TILE_DEVICE_EXTERN
int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2i8"
);
CK_TILE_DEVICE_EXTERN
int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i8"
);
// buffer load i16
CK_TILE_DEVICE_EXTERN
int16_t
llvm_amdgcn_raw_buffer_load_i16
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.i16"
);
CK_TILE_DEVICE_EXTERN
int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2i16"
);
CK_TILE_DEVICE_EXTERN
int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i16"
);
// buffer load i32
CK_TILE_DEVICE_EXTERN
int32_t
llvm_amdgcn_raw_buffer_load_i32
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.i32"
);
CK_TILE_DEVICE_EXTERN
int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2i32"
);
CK_TILE_DEVICE_EXTERN
int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4i32"
);
// buffer load fp16
CK_TILE_DEVICE_EXTERN
_Float16
llvm_amdgcn_raw_buffer_load_fp16
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.f16"
);
CK_TILE_DEVICE_EXTERN
fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2f16"
);
CK_TILE_DEVICE_EXTERN
fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4f16"
);
// buffer load fp32
CK_TILE_DEVICE_EXTERN
float
llvm_amdgcn_raw_buffer_load_fp32
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.f32"
);
CK_TILE_DEVICE_EXTERN
fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v2f32"
);
CK_TILE_DEVICE_EXTERN
fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4
(
int32x4_t
srsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.load.v4f32"
);
// buffer store i8
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i8
(
int8_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i8"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i8x2
(
int8x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i8"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i8x4
(
int8x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i8"
);
// buffer store i16
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i16
(
int16_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i16x2
(
int16x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i16x4
(
int16x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i16"
);
// buffer store i32
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i32
(
int32_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.i32"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i32x2
(
int32x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2i32"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_i32x4
(
int32x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4i32"
);
// buffer store fp16
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp16
(
_Float16
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.f16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp16x2
(
fp16x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2f16"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp16x4
(
fp16x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4f16"
);
// buffer store fp32
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp32
(
float
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.f32"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp32x2
(
fp32x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v2f32"
);
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_store_fp32x4
(
fp32x4_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.store.v4f32"
);
// buffer atomic-add fp16
CK_TILE_DEVICE_EXTERN
fp16x2_t
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
fp16x2_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"
);
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN
int32_t
llvm_amdgcn_raw_buffer_atomic_add_i32
(
int32_t
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.add.i32"
);
// buffer atomic-add fp32
CK_TILE_DEVICE_EXTERN
float
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
float
vdata
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
// buffer atomic-max fp64
CK_TILE_DEVICE_EXTERN
double
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
int32x4_t
rsrc
,
// dst_wave_buffer_resource
int
voffset
,
// dst_thread_addr_offset
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
CK_TILE_DEVICE
void
async_buffer_load_dword
(
void
*
smem
,
int32x4_t
rsrc
,
index_t
voffset
,
index_t
soffset
,
index_t
ioffset
/*max 0xFFF*/
,
index_t
/*flag*/
=
0
)
{
asm
volatile
(
"buffer_load_dword %1, %2, %3 offen offset:%4 lds"
:
"=r"
(
smem
)
/*dummy dependency for smem*/
:
"v"
(
voffset
),
"s"
(
rsrc
),
"s"
(
soffset
),
"n"
(
ioffset
)
:
"memory"
);
}
CK_TILE_DEVICE
void
async_buffer_load_fence
(
index_t
cnt
=
0
)
{
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
// memory coherency bit for buffer store/load instruction
// check ISA manual for each GFX target
// e.g. for
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
// page 67~68
enum struct amd_buffer_coherence_enum
{
    coherence_default = 0, // default value
    glc               = 1,
    slc               = 2,
    glc_slc           = 3,
};
template
<
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
thread_buffer
<
int8_t
,
N
>
amd_buffer_load_impl_with_bytes
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
using
rtn_type
=
thread_buffer
<
int8_t
,
N
>
;
if
constexpr
(
N
==
1
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_i8
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
2
)
{
int16_t
tmp
=
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
else
if
constexpr
(
N
==
4
)
{
int32_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
else
if
constexpr
(
N
==
8
)
{
int32x2_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
else
if
constexpr
(
N
==
16
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
else
if
constexpr
(
N
==
32
)
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
thread_buffer
<
int32_t
,
8
>
tmp
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
0
>
{})
=
tmp0
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
1
>
{})
=
tmp1
;
return
bit_cast
<
rtn_type
>
(
tmp
);
}
else
if
constexpr
(
N
==
64
)
{
int32x4_t
tmp0
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp1
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp2
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
int32x4_t
tmp3
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
int32_t
),
static_cast
<
index_t
>
(
coherence
));
thread_buffer
<
int32_t
,
16
>
tmp
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
0
>
{})
=
tmp0
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
1
>
{})
=
tmp1
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
2
>
{})
=
tmp2
;
tmp
.
template
get_as
<
int32x4_t
>()(
number
<
3
>
{})
=
tmp3
;
return
bit_cast
<
rtn_type
>
(
tmp
);
}
}
#ifndef BUFFER_LOAD_USE_INLINEASM
#define BUFFER_LOAD_USE_INLINEASM 0
#endif
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
thread_buffer
<
T
,
N
>
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
)
{
static_assert
(
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
using
rtn_type
=
thread_buffer
<
T
,
N
>
;
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
// fp32
{
if
constexpr
(
N
==
1
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp32
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
2
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp32x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
4
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
8
)
{
thread_buffer
<
float
,
8
>
tmp
;
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
;
}
else
if
constexpr
(
N
==
16
)
{
thread_buffer
<
float
,
16
>
tmp
;
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
0
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
1
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
2
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
8
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
tmp
.
template
get_as
<
fp32x4_t
>()(
number
<
3
>
{})
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
+
12
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
return
tmp
;
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
2
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
4
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_fp16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
8
)
{
// use fp32 load to mimic fp16 load
fp32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_fp32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
bf16_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_i16
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
2
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_i16x2
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
4
)
{
return
bit_cast
<
rtn_type
>
(
llvm_amdgcn_raw_buffer_load_i16x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
)));
}
else
if
constexpr
(
N
==
8
)
{
int32x4_t
tmp
=
llvm_amdgcn_raw_buffer_load_i32x4
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
return
bit_cast
<
rtn_type
>
(
tmp
);
}
}
else
// other datatype
{
auto
raw_data
=
amd_buffer_load_impl_with_bytes
<
sizeof
(
T
)
*
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
);
return
bit_cast
<
rtn_type
>
(
raw_data
);
}
}
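// Usage sketch (added for illustration, not part of the original header): the coherence policy is
// chosen at compile time through the template argument, e.g. a glc (globally coherent) 4-wide
// float load; the resource and offset arguments are whatever the caller has already computed.
CK_TILE_DEVICE thread_buffer<float, 4> example_glc_load_fp32x4(int32x4_t srd,
                                                               index_t thread_byte_offset)
{
    return amd_buffer_load_impl<float, 4, amd_buffer_coherence_enum::glc>(
        srd, thread_byte_offset, 0);
}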
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
amd_buffer_load_raw_impl
(
thread_buffer
<
T
,
N
>&
dst
,
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
,
index_t
flag
=
0
)
{
constexpr
index_t
bytes
=
sizeof
(
T
)
*
N
;
static_assert
(
bytes
==
1
||
bytes
==
2
||
bytes
==
4
||
bytes
==
8
||
bytes
==
16
,
"wrong! not supported by buffer_load instruction"
);
using
type
=
thread_buffer
<
T
,
N
>
;
if
constexpr
(
oob_conditional_check
)
{
buffer_load_if
<
sizeof
(
type
)
>
{}(
dst
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
,
flag
);
}
else
{
buffer_load
<
sizeof
(
type
)
>
{}(
dst
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
0
,
flag
);
}
}
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_async_buffer_load_impl
(
T
*
smem
,
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
,
index_t
src_immediate_addr_offset
=
0
)
{
static_assert
(
sizeof
(
T
)
*
N
==
4
,
"wrong! not implemented vector size"
);
async_buffer_load_dword
(
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_immediate_addr_offset
);
}
template
<
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
||
N
==
32
||
N
==
64
,
"wrong! not implemented"
);
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i8
(
bit_cast
<
int8_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i32x2
(
bit_cast
<
int32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
16
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
bit_cast
<
int32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
32
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
64
)
{
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
4
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
8
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i32x4
(
src_thread_data
.
template
get_as
<
int32x4_t
>()[
number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
)
*
12
,
static_cast
<
index_t
>
(
coherence
));
}
}
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_buffer_store_impl
(
const
thread_buffer
<
T
,
N
>
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
(
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
// fp32
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp32
(
bit_cast
<
float
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp32x2
(
bit_cast
<
fp32x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
fp32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
.
template
get_as
<
fp32x4_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_fp32x4
(
src_thread_data
.
template
get_as
<
fp32x4_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
// fp16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_fp16
(
bit_cast
<
_Float16
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_fp16x2
(
bit_cast
<
fp16x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_fp16x4
(
bit_cast
<
fp16x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
#if 0
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(fp16_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4
(
bit_cast
<
fp32x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
#endif
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
bf16_t
>::
value
)
// bf16
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_store_i16
(
bit_cast
<
int16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_store_i16x2
(
bit_cast
<
int16x2_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
bit_cast
<
int16x4_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
}
else
if
constexpr
(
N
==
8
)
{
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
.
template
get_as
<
int16x4_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
static_cast
<
index_t
>
(
coherence
));
llvm_amdgcn_raw_buffer_store_i16x4
(
src_thread_data
.
template
get_as
<
int16x4_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
bf16_t
),
static_cast
<
index_t
>
(
coherence
));
}
}
else
{
using
r_t
=
thread_buffer
<
int8_t
,
sizeof
(
T
)
*
N
>
;
amd_buffer_store_impl_with_bytes
<
sizeof
(
T
)
*
N
,
coherence
>
(
bit_cast
<
r_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
);
}
}
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
amd_buffer_store_raw_impl
(
const
thread_buffer
<
T
,
N
>&
dst_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
,
index_t
is_valid_element
=
1
)
{
constexpr
index_t
bytes
=
sizeof
(
T
)
*
N
;
static_assert
(
bytes
==
1
||
bytes
==
2
||
bytes
==
4
||
bytes
==
8
||
bytes
==
16
,
"wrong! not supported by buffer_store instruction"
);
using
type
=
thread_buffer
<
T
,
N
>
;
if
constexpr
(
oob_conditional_check
)
{
buffer_store_if
<
sizeof
(
type
)
>
{}(
dst_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
,
is_valid_element
);
}
else
{
buffer_store
<
sizeof
(
type
)
>
{}(
dst_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
}
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
amd_buffer_atomic_add_impl
(
const
thread_buffer
<
T
,
N
>&
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
((
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
if
constexpr
(
std
::
is_same
<
T
,
float
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
bit_cast
<
float
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
float
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
float
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
float
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_fp32
(
src_thread_data
.
template
get_as
<
float
>()[
number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
float
),
0
);
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
fp16_t
>::
value
)
{
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
bit_cast
<
fp16_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
4
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
src_thread_data
.
template
get_as
<
fp16x2_t
>()[
i
],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
i
*
sizeof
(
fp16x2_t
),
0
);
});
}
else
if
constexpr
(
N
==
8
)
{
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2
(
src_thread_data
.
template
get_as
<
fp16x2_t
>()[
i
],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
i
*
sizeof
(
fp16x2_t
),
0
);
});
}
}
else
if
constexpr
(
std
::
is_same
<
T
,
int32_t
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_add_i32
(
bit_cast
<
int32_t
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
int32_t
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
int32_t
),
0
);
llvm_amdgcn_raw_buffer_atomic_add_i32
(
src_thread_data
.
template
get_as
<
int32_t
>()[
number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
int32_t
),
0
);
}
}
}
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
amd_buffer_atomic_max_impl
(
const
thread_buffer
<
T
,
N
>
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
((
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
if
constexpr
(
std
::
is_same
<
T
,
double
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
bit_cast
<
double
>
(
src_thread_data
),
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
.
template
get_as
<
double
>()[
number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
double
),
0
);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
// oob_conditional_check : dynamic check if out-of-bound
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
                                            index_t src_thread_element_offset,
                                            bool src_thread_element_valid,
                                            index_t src_element_space_size)
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = [&]() {
        if constexpr(oob_conditional_check)
            return src_thread_element_valid ? 0 : 0x80000000;
        else
            return 0;
    }();

    return amd_buffer_load_impl<T, N, coherence>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
    thread_buffer<T, N> tmp = amd_buffer_load_impl<T, N, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0);

    if constexpr(oob_conditional_check)
        return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
    else
        return tmp;
#endif
}
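// Usage sketch (added for illustration, not part of the original header): a guarded per-thread
// gather that zero-fills out-of-bounds lanes. The offset and validity flag would normally come
// from the caller's tile coordinate math; here they are simply forwarded.
CK_TILE_DEVICE thread_buffer<float, 4> example_safe_gather_fp32x4(const float* p_src_wave,
                                                                  index_t element_offset,
                                                                  bool in_bounds,
                                                                  index_t num_elements)
{
    // threads with in_bounds == false receive {0, 0, 0, 0}
    return amd_buffer_load_invalid_element_return_zero<float, 4>(
        p_src_wave, element_offset, in_bounds, num_elements);
}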
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
thread_buffer
<
T
,
N
>
amd_buffer_load_invalid_element_return_customized_value
(
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
bool
src_thread_element_valid
,
index_t
src_element_space_size
,
T
customized_value
)
{
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
thread_buffer
<
T
,
N
>
tmp
=
amd_buffer_load_impl
<
T
,
N
,
coherence
>
(
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
);
if
constexpr
(
oob_conditional_check
)
return
src_thread_element_valid
?
tmp
:
thread_buffer
<
T
,
N
>
{
customized_value
};
else
return
tmp
;
}
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
amd_buffer_load_raw
(
thread_buffer
<
T
,
N
>&
dst
,
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
index_t
src_element_space_size
,
index_t
is_valid_element
=
0
)
{
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
amd_buffer_load_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
>
(
dst
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
is_valid_element
);
}
// unfortunately async copy can not make sure invalid data is zero inside LDS
// ... unless people manually write zero to LDS at the proper address.
// so not support invalid_element check for now.
// buffer_load OOB still working.
template
<
typename
T
,
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob
(
T
*
smem
,
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
index_t
src_element_space_size
)
{
const
int32x4_t
src_wave_buffer_resource
=
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
);
}
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer<T, N>& src_thread_data,
                                     T* p_dst_wave,
                                     const index_t dst_thread_element_offset,
                                     const bool dst_thread_element_valid,
                                     const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = [&]() {
        if constexpr(oob_conditional_check)
            return dst_thread_element_valid ? 0 : 0x80000000;
        else
            return 0;
    }();

    amd_buffer_store_impl<T, N, coherence>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if constexpr(oob_conditional_check)
    {
        if(dst_thread_element_valid)
        {
            amd_buffer_store_impl<T, N, coherence>(
                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
        }
    }
    else
    {
        amd_buffer_store_impl<T, N, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
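// Note on the OFFSET_TRICK path above (explanatory, based on how buffer resources bound-check):
// an invalid lane does not branch, it adds 0x80000000 to its byte offset; since the wave buffer
// resource is built with num_records = dst_element_space_size * sizeof(T), the shifted offset
// fails the hardware range check and the store is silently dropped. A hypothetical caller only
// has to pass the per-lane validity flag:
#if 0
template <typename T>
CK_TILE_DEVICE void example_masked_store(const thread_buffer<T, 4>& data,
                                         T* p_dst_wave,
                                         index_t element_offset,
                                         index_t element_space_size)
{
    const bool valid = element_offset + 4 <= element_space_size;
    amd_buffer_store<T, 4>(data, p_dst_wave, element_offset, valid, element_space_size);
}
#endif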
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(
        src_thread_data,
        dst_wave_buffer_resource,
        dst_thread_addr_offset,
        0,
        dst_thread_element_valid);
}
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_data,
                                          T* p_dst_wave,
                                          const index_t dst_thread_element_offset,
                                          const bool dst_thread_element_valid,
                                          const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_add_impl<T, N>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_atomic_add_impl<T, N>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_data,
                                          T* p_dst_wave,
                                          const index_t dst_thread_element_offset,
                                          const bool dst_thread_element_valid,
                                          const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_max_impl<T, N>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_atomic_max_impl<T, N>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                index_t size,
                                index_t voffset,
                                index_t soffset,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                                  const index_t global_offset,
                                                  T* lds_base_ptr,
                                                  const index_t lds_offset,
                                                  const bool is_valid,
                                                  const index_t src_element_space_size)
{
    // Direct loads require that each thread reads and writes exactly a single DWORD.
    constexpr auto dword_bytes      = 4;
    constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
    static_assert(bytes_per_thread == dword_bytes);

    const uint32_t* global_ptr =
        reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));

    const int32x4_t src_resource =
        make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));

    const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
    T* lds_ptr = lds_base_ptr + lds_offset;

    auto const lds_ptr_sgpr =
        __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));

    asm volatile("s_mov_b32 m0, %0; \n\t"
                 "buffer_load_dword %1, %2, 0 offen lds; \n\t" ::"s"(lds_ptr_sgpr),
                 "v"(global_offset_bytes),
                 "s"(src_resource));
#else
    // LDS pointer must be attributed with the LDS address space.
    __attribute__((address_space(3))) uint32_t* lds_ptr =
        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));

    llvm_amdgcn_raw_buffer_load_lds(
        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
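// Illustrative usage sketch (hypothetical kernel, not from the original source): each lane
// moves exactly one DWORD from global memory straight into LDS, so for float data
// NumElemsPerThread must be 1. Lanes with is_valid == false get an out-of-range source
// offset and are dropped (the corresponding LDS words are left untouched).
#if 0
__global__ void example_direct_load_tile(const float* p_global, index_t n)
{
    __shared__ float smem[256];
    const index_t tid = threadIdx.x;
    amd_direct_load_global_to_lds<float, 1>(p_global, tid, smem, tid, tid < n, n);
    block_sync_lds_direct_load(); // declared in arch.hpp; waits on vmcnt/lgkmcnt then barriers
}
#endif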
} // namespace ck_tile
include/ck_tile/core/arch/arch.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {

enum struct address_space_enum
{
    generic,
    global,
    lds,
    sgpr,
    vgpr,
};

enum struct memory_operation_enum
{
    set,
    atomic_add,
    atomic_max,
    add
};

CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
    // warpSize is defined by HIP
    return warpSize;
}

CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }

CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }

// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }

CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }

// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }

CK_TILE_DEVICE index_t get_warp_id()
{
    return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}

CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
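// Illustrative sketch (hypothetical kernel name, disabled so it does not affect this header):
// combining the wave-level helpers above to build a per-wave linear index.
#if 0
__global__ void example_wave_index_kernel(index_t* out)
{
    const index_t wid  = get_warp_id(); // uniform across the wave (readfirstlane)
    const index_t lane = get_lane_id(); // 0 .. warpSize-1 inside the wave
    out[get_thread_id()] = wid * get_warp_size() + lane;
}
#endif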
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
    asm volatile("\
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
#else
    __syncthreads();
#endif
}

CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    asm volatile("\
    s_waitcnt vmcnt(0) \n \
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
}

CK_TILE_DEVICE void s_nop()
{
#if 1
    asm volatile("\
    s_nop 0 \n \
    " ::);
#else
    __builtin_amdgcn_sched_barrier(0);
#endif
}

} // namespace ck_tile
include/ck_tile/core/arch/utility.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {

// TODO: we have "memory" clobber here because this inline asm is used for async copy
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
{
    asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
}

// NOTE: this is an immediate value
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
{
    asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_up(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    const uint32_t wrap_around_lane_delta = warpSize - lane_delta;

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
    return __shfl_down(v_local, lane_delta);
#elif 1
    static_assert(sizeof(T) == sizeof(int32_t), "wrong!");

    const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
        (__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));

    return bit_cast<T>(v_remote_tmp);
#endif
}
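// Illustrative sketch (hypothetical helper, assumes a 32-bit payload): a wave-wide sum built
// on warp_shuffle_down. Because the bpermute-based shuffle wraps around the wave, every lane
// converges to the same total after log2(warpSize) steps.
#if 0
CK_TILE_DEVICE float example_wave_reduce_sum(float v)
{
    for(uint32_t d = warpSize / 2; d > 0; d /= 2)
        v += warp_shuffle_down(v, d); // add the value held d lanes above (with wrap-around)
    return v;
}
#endif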
} // namespace ck_tile
include/ck_tile/core/config.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif
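// Note: the *_DEFAULT selections above are guarded by #ifndef, so they can be overridden from
// the build line. A hypothetical compile command (adjust to your build system):
//   hipcc -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN \
//         -DCK_TILE_FLOAT_TO_FP8_DEFAULT=CK_TILE_FLOAT_TO_FP8_STOCHASTIC ...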
// In the old ROCm period we had to use the tuple-array implementation for this,
// so turn on _USE_TUPLE if you meet a compiler error, otherwise _USE_ARRAY by default.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// if using tuple-array as the thread_buffer implementation, {} brace init needs to be supported
// ... with similar behavior to array
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
#else
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
#endif
#endif
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif
#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
#define CK_TILE_USE_AMD_BUFFER_STORE 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
include/ck_tile/core/container/array.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace
ck_tile
{
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0}; => {0, 0, 0, 0}, clean
// array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0})
// use make_array_with({...}) to construct an array whose behavior is compatible with old ck
// TODO: manually added constructor same as old ck
template
<
typename
T_
,
index_t
N_
>
struct
array
{
using
value_type
=
T_
;
static
constexpr
index_t
N
=
N_
;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type
data
[
N
];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE
constexpr
array
()
:
data
{}
{}
// TODO: will initialize the data[] with the last value repeatedly
// behavior different from std
CK_TILE_HOST_DEVICE
constexpr
array
(
std
::
initializer_list
<
value_type
>
ilist
)
{
constexpr
index_t
list_size
=
std
::
initializer_list
<
value_type
>
{}.
size
();
static_assert
(
list_size
<=
N
,
"out of bound"
);
index_t
i
=
0
;
value_type
vlast
=
value_type
{};
for
(
const
value_type
&
val
:
ilist
)
{
data
[
i
]
=
val
;
vlast
=
val
;
++
i
;
}
for
(;
i
<
N
;
++
i
)
{
data
[
i
]
=
vlast
;
}
}
template
<
typename
Y
,
typename
=
std
::
enable_if_t
<
std
::
is_convertible_v
<
Y
,
value_type
>
||
std
::
is_constructible_v
<
Y
,
value_type
>>>
CK_TILE_HOST_DEVICE
explicit
constexpr
array
(
Y
c
)
{
for
(
auto
i
=
0
;
i
<
size
();
i
++
)
data
[
i
]
=
static_cast
<
value_type
>
(
c
);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE
static
constexpr
auto
size
()
{
return
N
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
is_static_v
<
value_type
>
;
}
// clang-format off
CK_TILE_HOST_DEVICE
constexpr
auto
&
get
()
{
return
data
;
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get
()
const
{
return
data
;
}
CK_TILE_HOST_DEVICE
constexpr
auto
&
get
(
index_t
i
)
{
return
data
[
i
];
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get
(
index_t
i
)
const
{
return
data
[
i
];
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
get
()
{
return
data
[
I
];
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get
()
const
{
return
data
[
I
];
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
get
(
number
<
I
>
)
{
return
data
[
I
];
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get
(
number
<
I
>
)
const
{
return
data
[
I
];
}
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
index_t
i
)
{
return
get
(
i
);
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
index_t
i
)
const
{
return
get
(
i
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
()
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
()
const
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
at
(
number
<
I
>
)
{
return
get
(
I
);
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
at
(
number
<
I
>
)
const
{
return
get
(
I
);
}
CK_TILE_HOST_DEVICE
constexpr
const
value_type
&
operator
[](
index_t
i
)
const
{
return
get
(
i
);
}
CK_TILE_HOST_DEVICE
constexpr
value_type
&
operator
[](
index_t
i
)
{
return
get
(
i
);
}
CK_TILE_HOST_DEVICE
constexpr
value_type
&
operator
()(
index_t
i
)
{
return
get
(
i
);
}
// TODO: compatible
#if 0
template <typename ArrayLike>
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
{
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = arr[i];
}
return *this;
}
#endif
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template
<
typename
Tx
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_as
()
{
AR_AS_COM_
();
return
reinterpret_cast
<
array
<
Tx
,
vx
>&>
(
data
);
}
template
<
typename
Tx
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_as
()
const
{
AR_AS_COM_
();
return
reinterpret_cast
<
const
array
<
Tx
,
vx
>&>
(
data
);
}
// below index is for index *AFTER* type convert, not before
template
<
typename
Tx
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_as
(
index_t
i
)
{
AR_AS_COM_
();
return
reinterpret_cast
<
array
<
Tx
,
vx
>&>
(
data
).
at
(
i
);
}
template
<
typename
Tx
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_as
(
index_t
i
)
const
{
AR_AS_COM_
();
return
reinterpret_cast
<
const
array
<
Tx
,
vx
>&>
(
data
).
at
(
i
);
}
template
<
typename
Tx
,
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_as
(
number
<
I
>
)
{
AR_AS_COM_
();
return
reinterpret_cast
<
array
<
Tx
,
vx
>&>
(
data
).
at
(
number
<
I
>
{});
}
template
<
typename
Tx
,
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_as
(
number
<
I
>
)
const
{
AR_AS_COM_
();
return
reinterpret_cast
<
const
array
<
Tx
,
vx
>&>
(
data
).
at
(
number
<
I
>
{});
}
template
<
typename
Tx
>
CK_TILE_HOST_DEVICE
constexpr
void
set_as
(
index_t
i
,
const
Tx
&
x
)
{
AR_AS_COM_
();
reinterpret_cast
<
array
<
Tx
,
vx
>&>
(
data
).
at
(
i
)
=
x
;
}
template
<
typename
Tx
,
index_t
I
>
CK_TILE_HOST_DEVICE
constexpr
void
set_as
(
number
<
I
>
,
const
Tx
&
x
)
{
AR_AS_COM_
();
reinterpret_cast
<
array
<
Tx
,
vx
>&>
(
data
).
at
(
number
<
I
>
{})
=
x
;
}
#undef AR_AS_COM_
// clang-format on
};
// empty Array
template
<
typename
T
>
struct
array
<
T
,
0
>
{
using
value_type
=
T
;
CK_TILE_HOST_DEVICE
constexpr
array
()
{}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
size
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
is_static_v
<
T
>
;
};
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"array{size: 0, data: []}"
);
}
};
template
<
typename
>
struct
vector_traits
;
// specialization for array
template
<
typename
T
,
index_t
N
>
struct
vector_traits
<
array
<
T
,
N
>>
{
using
scalar_type
=
T
;
static
constexpr
index_t
vector_size
=
N
;
};
namespace
details
{
template
<
class
>
struct
is_ref_wrapper
:
std
::
false_type
{
};
template
<
class
T
>
struct
is_ref_wrapper
<
std
::
reference_wrapper
<
T
>>
:
std
::
true_type
{
};
template
<
class
T
>
using
not_ref_wrapper
=
std
::
negation
<
is_ref_wrapper
<
std
::
decay_t
<
T
>>>
;
template
<
class
D
,
class
...
>
struct
return_type_helper
{
using
type
=
D
;
};
template
<
class
...
Ts
>
struct
return_type_helper
<
void
,
Ts
...
>
:
std
::
common_type
<
Ts
...
>
{
static_assert
(
std
::
conjunction_v
<
not_ref_wrapper
<
Ts
>
...
>
,
"Ts cannot contain reference_wrappers when D is void"
);
};
template
<
class
D
,
class
...
Ts
>
using
return_type
=
array
<
typename
return_type_helper
<
D
,
Ts
...
>::
type
,
sizeof
...(
Ts
)
>
;
}
// namespace details
template
<
typename
D
=
void
,
typename
...
Ts
>
CK_TILE_HOST_DEVICE
constexpr
details
::
return_type
<
D
,
Ts
...
>
make_array
(
Ts
&&
...
ts
)
{
return
{
std
::
forward
<
Ts
>
(
ts
)...};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old ck's initializer: make an array and fill it with the last element from the
// initializer_list
template
<
typename
T
,
index_t
Size
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_array_with
(
std
::
initializer_list
<
T
>
ilist
)
{
return
array
<
T
,
Size
>
(
ilist
);
}
template
<
typename
T
,
index_t
Size
>
CK_TILE_HOST_DEVICE
constexpr
bool
operator
==
(
const
array
<
T
,
Size
>&
a
,
const
array
<
T
,
Size
>&
b
)
{
bool
same
=
true
;
for
(
index_t
i
=
0
;
i
<
Size
;
++
i
)
{
if
(
a
[
i
]
!=
b
[
i
])
{
same
=
false
;
break
;
}
}
return
same
;
}
template
<
typename
T
,
index_t
Size
>
CK_TILE_HOST_DEVICE
constexpr
bool
operator
!=
(
const
array
<
T
,
Size
>&
a
,
const
array
<
T
,
Size
>&
b
)
{
return
!
(
a
==
b
);
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
static_assert
(
N
<=
X
::
size
(),
""
);
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
}
// namespace ck_tile
include/ck_tile/core/container/container_helper.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace
ck_tile
{
template
<
typename
TData
,
index_t
NSize
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_push_back
(
const
array
<
TData
,
NSize
>&
a
,
const
TData
&
x
)
{
array
<
TData
,
NSize
+
1
>
r
;
static_for
<
0
,
NSize
,
1
>
{}([
&
r
,
&
a
](
auto
i
)
constexpr
{
r
(
i
)
=
a
[
i
];
});
r
[
number
<
NSize
>
{}]
=
x
;
return
r
;
}
template
<
typename
...
Ts
,
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_push_front
(
const
tuple
<
Ts
...
>&
a
,
const
T
&
x
)
{
return
container_concat
(
make_tuple
(
x
),
a
);
}
template
<
typename
...
Ts
,
typename
T
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_push_back
(
const
tuple
<
Ts
...
>&
a
,
const
T
&
x
)
{
return
container_concat
(
a
,
make_tuple
(
x
));
}
// reorder array
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_new2old
(
const
array
<
TData
,
NSize
>&
old_array
,
sequence
<
IRs
...
>
/*new2old*/
)
{
static_assert
(
NSize
==
sizeof
...(
IRs
),
"wrong! size not consistent"
);
static_assert
(
is_valid_sequence_map
<
sequence
<
IRs
...
>>
{},
"wrong! invalid reorder map"
);
return
make_array
<
remove_cvref_t
<
TData
>>
(
old_array
[
IRs
]...);
}
template
<
typename
TData
,
index_t
NSize
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_old2new
(
const
array
<
TData
,
NSize
>&
old_array
,
sequence
<
IRs
...
>
old2new
)
{
return
container_reorder_given_new2old
(
old_array
,
typename
sequence_map_inverse
<
decltype
(
old2new
)
>::
type
{});
}
// reorder array
template
<
typename
TData
,
index_t
NSize
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_new2old
(
const
array
<
TData
,
NSize
>&
old_array
,
const
map
<
index_t
,
index_t
>&
new2old
)
{
array
<
TData
,
NSize
>
new_array
;
for
(
const
auto
&
[
new_pos
,
old_pos
]
:
new2old
)
{
new_array
(
new_pos
)
=
old_array
[
old_pos
];
}
return
new_array
;
}
template
<
typename
TData
,
index_t
NSize
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_old2new
(
const
array
<
TData
,
NSize
>&
old_array
,
const
map
<
index_t
,
index_t
>&
old2new
)
{
array
<
TData
,
NSize
>
new_array
;
for
(
const
auto
&
[
old_pos
,
new_pos
]
:
old2new
)
{
new_array
(
new_pos
)
=
old_array
[
old_pos
];
}
return
new_array
;
}
// reorder tuple
template
<
typename
...
Ts
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_new2old
(
const
tuple
<
Ts
...
>&
old_tuple
,
sequence
<
IRs
...
>
/*new2old*/
)
{
static_assert
(
sizeof
...(
Ts
)
==
sizeof
...(
IRs
),
"wrong! size not consistent"
);
static_assert
(
is_valid_sequence_map
<
sequence
<
IRs
...
>>
{},
"wrong! invalid reorder map"
);
return
make_tuple
(
old_tuple
[
number
<
IRs
>
{}]...);
}
template
<
typename
...
Ts
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_old2new
(
const
tuple
<
Ts
...
>&
old_tuple
,
sequence
<
IRs
...
>
old2new
)
{
return
container_reorder_given_new2old
(
old_tuple
,
typename
sequence_map_inverse
<
decltype
(
old2new
)
>::
type
{});
}
// reorder sequence
template
<
index_t
...
Is
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_new2old
(
sequence
<
Is
...
>
/* old_seq */
,
sequence
<
IRs
...
>
/*new2old*/
)
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
IRs
),
"wrong! size not consistent"
);
static_assert
(
is_valid_sequence_map
<
sequence
<
IRs
...
>>
{},
"wrong! invalid reorder map"
);
return
sequence
<
sequence
<
Is
...
>::
at
(
number
<
IRs
>
{})...
>
{};
}
template
<
index_t
...
Is
,
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reorder_given_old2new
(
sequence
<
Is
...
>
old_seq
,
sequence
<
IRs
...
>
/* old2new */
)
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
IRs
),
"wrong! size not consistent"
);
static_assert
(
is_valid_sequence_map
<
sequence
<
IRs
...
>>
{},
"wrong! invalid reorder map"
);
constexpr
auto
new2old
=
typename
sequence_map_inverse
<
sequence
<
IRs
...
>>::
type
{};
return
container_reorder_given_new2old
(
old_seq
,
new2old
);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, number<IBegin>{}, init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template
<
typename
Container
,
typename
Reduce
,
typename
ROld
,
index_t
I
,
index_t
IEnd
,
index_t
IStep
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reduce_impl
(
const
Container
&
x
,
Reduce
reduce
,
ROld
r_old
,
number
<
I
>
i
,
number
<
IEnd
>
,
number
<
IStep
>
)
{
auto
r_new
=
reduce
(
x
[
i
],
r_old
);
if
constexpr
(
i
.
value
<
IEnd
-
IStep
)
{
return
container_reduce_impl
(
x
,
reduce
,
r_new
,
i
+
number
<
IStep
>
{},
number
<
IEnd
>
{},
number
<
IStep
>
{});
}
else
{
return
r_new
;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template
<
typename
Container
,
typename
Reduce
,
typename
Init
,
index_t
IBegin
=
0
,
index_t
IEnd
=
Container
::
size
(),
index_t
IStep
=
1
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reduce
(
const
Container
&
x
,
Reduce
reduce
,
Init
init
,
number
<
IBegin
>
=
number
<
0
>
{},
number
<
IEnd
>
=
number
<
Container
::
size
()
>
{},
number
<
IStep
>
=
number
<
1
>
{})
{
static_assert
((
IEnd
-
IBegin
)
%
IStep
==
0
,
"wrong!"
);
if
constexpr
(
IEnd
>
IBegin
)
{
return
container_reduce_impl
(
x
,
reduce
,
init
,
number
<
IBegin
>
{},
number
<
IEnd
>
{},
number
<
IStep
>
{});
}
else
{
return
init
;
}
}
#endif
template
<
typename
TData
,
index_t
NSize
,
typename
Reduce
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_inclusive_scan
(
const
array
<
TData
,
NSize
>&
x
,
Reduce
f
,
TData
init
)
{
array
<
TData
,
NSize
>
y
;
TData
r
=
init
;
static_for
<
NSize
-
1
,
0
,
-
1
>
{}([
&
](
auto
i
)
{
r
=
f
(
r
,
x
[
i
]);
y
(
i
)
=
r
;
});
r
=
f
(
r
,
x
[
number
<
0
>
{}]);
y
(
number
<
0
>
{})
=
r
;
return
y
;
}
template
<
typename
TData
,
index_t
NSize
,
typename
Reduce
,
typename
Init
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_exclusive_scan
(
const
array
<
TData
,
NSize
>&
x
,
Reduce
f
,
Init
init
)
{
#if 0
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(number<0>{}) = r;
return y;
#else
array
<
TData
,
NSize
>
y
;
TData
r
=
init
;
for
(
index_t
i
=
NSize
-
1
;
i
>
0
;
--
i
)
{
y
(
i
)
=
r
;
r
=
f
(
r
,
x
[
i
]);
}
y
(
0
)
=
r
;
return
y
;
#endif
}
template
<
index_t
...
Is
,
typename
Reduce
,
index_t
Init
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_exclusive_scan
(
const
sequence
<
Is
...
>&
seq
,
Reduce
f
,
number
<
Init
>
)
{
return
reverse_exclusive_scan_sequence
(
seq
,
f
,
number
<
Init
>
{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template
<
typename
...
Xs
,
typename
Reduce
,
index_t
I
,
typename
YOld
,
typename
ROld
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_exclusive_scan_impl
(
const
tuple
<
Xs
...
>&
x
,
Reduce
reduce
,
number
<
I
>
i
,
YOld
y_old
,
ROld
r_old
)
{
auto
r_new
=
reduce
(
x
[
i
],
r_old
);
auto
y_new
=
container_push_front
(
y_old
,
r_new
);
if
constexpr
(
i
.
value
>
1
)
{
// recursively call f/fs
return
container_reverse_exclusive_scan_impl
(
x
,
reduce
,
i
-
number
<
1
>
{},
y_new
,
r_new
);
}
else
{
return
y_new
;
}
}
template
<
typename
...
Xs
,
typename
Reduce
,
typename
Init
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_exclusive_scan
(
const
tuple
<
Xs
...
>&
x
,
Reduce
reduce
,
Init
init
)
{
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
return
container_reverse_exclusive_scan_impl
(
x
,
reduce
,
number
<
NSize
-
1
>
{},
make_tuple
(
init
),
init
);
}
#endif
// TODO: update to be like container_reverse_exclusive_scan to deal with tuple of number<>
template
<
typename
...
Xs
,
typename
Reduce
,
typename
TData
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_reverse_inclusive_scan
(
const
tuple
<
Xs
...
>&
x
,
Reduce
f
,
TData
init
)
{
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
tuple
<
Xs
...
>
y
;
TData
r
=
init
;
static_for
<
NSize
-
1
,
0
,
-
1
>
{}([
&
](
auto
i
)
{
r
=
f
(
r
,
x
[
i
]);
y
(
i
)
=
r
;
});
r
=
f
(
r
,
x
[
number
<
0
>
{}]);
y
(
number
<
0
>
{})
=
r
;
return
y
;
}
template
<
typename
X
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_concat
(
const
X
&
x
,
const
Ys
&
...
ys
)
{
return
container_concat
(
x
,
container_concat
(
ys
...));
}
template
<
typename
T
,
index_t
NX
,
index_t
NY
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_concat
(
const
array
<
T
,
NX
>&
ax
,
const
array
<
T
,
NY
>&
ay
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_array
<
T
>
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
}
template
<
typename
...
X
,
typename
...
Y
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_concat
(
const
tuple
<
X
...
>&
tx
,
const
tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
}
template
<
typename
Container
>
CK_TILE_HOST_DEVICE
constexpr
auto
container_concat
(
const
Container
&
x
)
{
return
x
;
}
template
<
typename
T
,
index_t
N
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
get_container_subset
(
const
array
<
T
,
N
>&
arr
,
sequence
<
Is
...
>
)
{
static_assert
(
N
>=
sizeof
...(
Is
),
"wrong! size"
);
if
constexpr
(
sizeof
...(
Is
)
>
0
)
{
return
make_array
<
T
>
(
arr
[
Is
]...);
}
else
{
return
array
<
T
,
0
>
{};
}
}
template
<
typename
...
Ts
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
get_container_subset
(
const
tuple
<
Ts
...
>&
tup
,
sequence
<
Is
...
>
)
{
static_assert
(
sizeof
...(
Ts
)
>=
sizeof
...(
Is
),
"wrong! size"
);
if
constexpr
(
sizeof
...(
Is
)
>
0
)
{
return
make_tuple
(
tup
[
number
<
Is
>
{}]...);
}
else
{
return
tuple
<>
{};
}
}
template
<
typename
T
,
index_t
N
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
void
set_container_subset
(
array
<
T
,
N
>&
y
,
sequence
<
Is
...
>
picks
,
const
array
<
T
,
sizeof
...(
Is
)
>&
x
)
{
static_assert
(
N
>=
sizeof
...(
Is
),
"wrong! size"
);
if
constexpr
(
sizeof
...(
Is
)
>
0
)
{
for
(
index_t
i
=
0
;
i
<
picks
.
size
();
++
i
)
{
y
(
picks
[
i
])
=
x
[
i
];
}
}
}
template
<
typename
Y
,
typename
X
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
void
set_container_subset
(
Y
&
y
,
sequence
<
Is
...
>
picks
,
const
X
&
x
)
{
static_assert
(
Y
::
size
()
>=
sizeof
...(
Is
)
&&
X
::
size
()
==
sizeof
...(
Is
),
"wrong! size"
);
if
constexpr
(
sizeof
...(
Is
)
>
0
)
{
static_for
<
0
,
sizeof
...(
Is
),
1
>
{}([
&
](
auto
i
)
{
y
(
picks
[
i
])
=
x
[
i
];
});
}
}
// return the index of the first occurrence in the sequence.
// return seq.size(), if not found
template
<
index_t
...
Is
>
constexpr
index_t
container_find
(
sequence
<
Is
...
>
seq
,
index_t
value
)
{
for
(
auto
i
=
0
;
i
<
seq
.
size
();
i
++
)
{
if
(
seq
[
i
]
==
value
)
return
i
;
}
return
seq
.
size
();
}
template
<
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
sequence_to_tuple_of_number
(
sequence
<
Is
...
>
)
{
using
Seq
=
sequence
<
Is
...
>
;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
index_t
tmp
=
Seq
::
at
(
i
);
return
number
<
tmp
>
{};
},
number
<
Seq
::
size
()
>
{});
}
#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#else
// constexpr index_t can't be captured "-Wunused-lambda-capture"
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#endif
}
// namespace ck_tile
include/ck_tile/core/container/map.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace
ck_tile
{
// naive map
template
<
typename
key
,
typename
data
,
index_t
max_size
=
128
>
struct
map
{
using
pair_type
=
tuple
<
key
,
data
>
;
using
impl_type
=
array
<
pair_type
,
max_size
>
;
impl_type
impl_
;
index_t
size_
;
struct
iterator
{
impl_type
&
impl_
;
index_t
pos_
;
CK_TILE_HOST_DEVICE
constexpr
iterator
(
impl_type
&
impl
,
index_t
pos
)
:
impl_
{
impl
},
pos_
{
pos
}
{
}
CK_TILE_HOST_DEVICE
constexpr
iterator
&
operator
++
()
{
pos_
++
;
return
*
this
;
}
CK_TILE_HOST_DEVICE
constexpr
bool
operator
!=
(
const
iterator
&
other
)
const
{
return
other
.
pos_
!=
pos_
;
}
CK_TILE_HOST_DEVICE
constexpr
pair_type
&
operator
*
()
{
return
impl_
.
at
(
pos_
);
}
};
struct
const_iterator
{
const
impl_type
&
impl_
;
index_t
pos_
;
CK_TILE_HOST_DEVICE
constexpr
const_iterator
(
const
impl_type
&
impl
,
index_t
pos
)
:
impl_
{
impl
},
pos_
{
pos
}
{
}
CK_TILE_HOST_DEVICE
constexpr
const_iterator
&
operator
++
()
{
pos_
++
;
return
*
this
;
}
CK_TILE_HOST_DEVICE
constexpr
bool
operator
!=
(
const
const_iterator
&
other
)
const
{
return
other
.
pos_
!=
pos_
;
}
CK_TILE_HOST_DEVICE
constexpr
const
pair_type
&
operator
*
()
const
{
return
impl_
.
at
(
pos_
);
}
};
CK_TILE_HOST_DEVICE
constexpr
map
()
:
impl_
{},
size_
{
0
}
{}
CK_TILE_HOST_DEVICE
constexpr
index_t
size
()
const
{
return
size_
;
}
CK_TILE_HOST_DEVICE
void
clear
()
{
size_
=
0
;
}
CK_TILE_HOST_DEVICE
constexpr
index_t
find_position
(
const
key
&
k
)
const
{
for
(
index_t
i
=
0
;
i
<
size
();
i
++
)
{
if
(
impl_
[
i
].
template
at
<
0
>()
==
k
)
{
return
i
;
}
}
return
size_
;
}
CK_TILE_HOST_DEVICE
constexpr
const_iterator
find
(
const
key
&
k
)
const
{
return
const_iterator
{
impl_
,
find_position
(
k
)};
}
CK_TILE_HOST_DEVICE
constexpr
iterator
find
(
const
key
&
k
)
{
return
iterator
{
impl_
,
find_position
(
k
)};
}
CK_TILE_HOST_DEVICE
constexpr
const
data
&
operator
[](
const
key
&
k
)
const
{
const
auto
it
=
find
(
k
);
// FIXME
// assert(it.pos_ < size());
return
impl_
[
it
.
pos_
].
template
at
<
1
>();
}
CK_TILE_HOST_DEVICE
constexpr
data
&
operator
()(
const
key
&
k
)
{
auto
it
=
find
(
k
);
// if entry not found
if
(
it
.
pos_
==
size
())
{
impl_
(
it
.
pos_
).
template
at
<
0
>()
=
k
;
size_
++
;
}
// FIXME
// assert(size_ <= max_size);
return
impl_
(
it
.
pos_
).
template
at
<
1
>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE
constexpr
const_iterator
begin
()
const
{
return
const_iterator
{
impl_
,
0
};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE
constexpr
const_iterator
end
()
const
{
return
const_iterator
{
impl_
,
size_
};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE
constexpr
iterator
begin
()
{
return
iterator
{
impl_
,
0
};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE
constexpr
iterator
end
()
{
return
iterator
{
impl_
,
size_
};
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"map{size_: %d, "
,
size_
);
//
printf
(
"impl_: ["
);
//
for
(
const
auto
&
[
k
,
d
]
:
*
this
)
{
printf
(
"{key: "
);
print
(
k
);
printf
(
", data: "
);
print
(
d
);
printf
(
"}, "
);
}
//
printf
(
"]"
);
//
printf
(
"}"
);
}
};
}
// namespace ck_tile
include/ck_tile/core/container/meta_data_buffer.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intended to be used by the user
template <index_t MaxSize>
struct meta_data_buffer
{
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}

    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
        : buffer_{}, size_{0}
    {
        push(x, xs...);
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr void push(const T& data)
    {
        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            auto tmp = bit_cast<array<std::byte, size>>(data);

            for(int i = 0; i < size; i++)
            {
                buffer_(size_) = tmp[i];

                size_++;
            }
        }
    }

    template <typename X, typename... Xs>
    CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
    {
        push(x);
        push(xs...);
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
    {
        T data;

        if constexpr(!std::is_empty_v<T>)
        {
            constexpr index_t size = sizeof(T);

            array<std::byte, size> tmp;

            for(int i = 0; i < size; i++)
            {
                tmp(i) = buffer_[pos];

                pos++;
            }

            data = bit_cast<T>(tmp);
        }

        return data;
    }

    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
    {
        constexpr index_t size = sizeof(T);

        array<std::byte, size> tmp;

        for(int i = 0; i < size; i++)
        {
            tmp(i) = buffer_[pos];

            pos++;
        }

        auto data = bit_cast<T>(tmp);

        return data;
    }

    //
    array<std::byte, MaxSize> buffer_;
    index_t size_ = 0;
};
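// Illustrative sketch (hypothetical helper, values assumed): push two POD values into a
// meta_data_buffer and read them back in the same order via pop<>().
#if 0
CK_TILE_HOST_DEVICE void example_meta_data_buffer_roundtrip()
{
    meta_data_buffer<32> meta(index_t{42}, 3.0f); // equivalent to push(42); push(3.0f);

    index_t pos = 0;
    auto i = meta.pop<index_t>(pos); // i == 42, pos advanced by sizeof(index_t)
    auto f = meta.pop<float>(pos);   // f == 3.0f
    (void)i;
    (void)f;
}
#endif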
} // namespace ck_tile
include/ck_tile/core/container/multi_index.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// Don't use this directly. This is for old CK's internal usage,
// in the future always use array instead
template <index_t N>
using multi_index = array<index_t, N>;

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
    return make_array<index_t>(index_t{xs}...);
}

template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
    return unpack([](auto... xs) { return make_multi_index(xs...); },
                  typename uniform_sequence_gen<NSize, 0>::type{});
}

template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
    return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}

template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}

template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
    static_assert(X::size() == NSize, "wrong! size not the same");
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");

    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
    return r;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");

    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
    return r;
}

template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
    using type = multi_index<NSize>;
    static_assert(T::size() == NSize, "wrong! size not the same");

    type r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
    return r;
}

// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
    multi_index<NSize> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}

// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
    return a * x;
}
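// Illustrative sketch (hypothetical helper): element-wise arithmetic on multi_index values
// using the operators defined above.
#if 0
CK_TILE_HOST_DEVICE void example_multi_index_arithmetic()
{
    auto a = make_multi_index(1, 2, 3);  // multi_index<3>, i.e. array<index_t, 3>
    auto b = make_zero_multi_index<3>(); // {0, 0, 0}
    b += a;                              // {1, 2, 3}
    auto c = 2 * (a + b);                // element-wise: {4, 8, 12}
    (void)c;
}
#endif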
} // namespace ck_tile
include/ck_tile/core/container/sequence.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace
ck_tile
{
template
<
index_t
,
index_t
,
index_t
>
struct
static_for
;
template
<
index_t
...>
struct
sequence
;
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
;
template
<
typename
>
struct
sequence_reverse
;
template
<
typename
>
struct
sequence_map_inverse
;
template
<
typename
>
struct
is_valid_sequence_map
;
template
<
index_t
I
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
sequence_pop_front
(
sequence
<
I
,
Is
...
>
);
template
<
typename
Seq
>
CK_TILE_HOST_DEVICE
constexpr
auto
sequence_pop_back
(
Seq
);
namespace
impl
{
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
template
<
index_t
I
,
typename
...
Ts
>
using
at_index_t
=
__type_pack_element
<
I
,
Ts
...
>
;
}
// namespace impl
// we could implement as below, similar to std. But let's reduce the symbol name...
// template< class T, T... Ints >
// class integer_sequence;
template
<
index_t
...
Is
>
struct
sequence
{
using
type
=
sequence
;
using
value_type
=
index_t
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
size
()
{
return
sizeof
...(
Is
);
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
true
;
};
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
get
()
{
static_assert
(
I
<
size
(),
"wrong! I too large"
);
return
number
<
impl
::
at_index_t
<
I
,
constant
<
Is
>
...
>
{}
>
{};
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
get
(
number
<
I
>
)
{
static_assert
(
I
<
size
(),
"wrong! I too large"
);
return
number
<
get
<
I
>
()
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
at
(
index_t
I
)
{
// the last dummy element is to prevent the compiler complaining about an empty array when mSize = 0
const
index_t
mData
[
size
()
+
1
]
=
{
Is
...,
0
};
return
mData
[
I
];
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
at
()
{
static_assert
(
I
<
size
(),
"wrong! I too large"
);
return
number
<
impl
::
at_index_t
<
I
,
constant
<
Is
>
...
>
{}
>
{};
}
template
<
index_t
I
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
at
(
number
<
I
>
)
{
static_assert
(
I
<
size
(),
"wrong! I too large"
);
return
number
<
get
<
I
>
()
>
{};
}
template
<
typename
I
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
[](
I
i
)
const
{
return
at
(
i
);
}
template
<
index_t
...
IRs
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
reorder_new_to_old
(
sequence
<
IRs
...
>
/*new2old*/
)
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
IRs
),
"wrong! reorder map should have the same size as sequence to be rerodered"
);
static_assert
(
is_valid_sequence_map
<
sequence
<
IRs
...
>>::
value
,
"wrong! invalid reorder map"
);
return
sequence
<
type
::
get
(
number
<
IRs
>
{})...
>
{};
}
// MapOld2New is sequence<...>
template
<
typename
MapOld2New
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
reorder_old_to_new
(
MapOld2New
)
{
static_assert
(
MapOld2New
::
size
()
==
size
(),
"wrong! reorder map should have the same size as sequence to be rerodered"
);
static_assert
(
is_valid_sequence_map
<
MapOld2New
>::
value
,
"wrong! invalid reorder map"
);
return
reorder_new_to_old
(
typename
sequence_map_inverse
<
MapOld2New
>::
type
{});
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
reverse
()
{
return
typename
sequence_reverse
<
type
>::
type
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
front
()
{
static_assert
(
size
()
>
0
,
"wrong!"
);
return
get
(
number
<
0
>
{});
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
back
()
{
static_assert
(
size
()
>
0
,
"wrong!"
);
return
get
(
number
<
size
()
-
1
>
{});
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
pop_front
()
{
return
sequence_pop_front
(
type
{});
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
pop_back
()
{
return
sequence_pop_back
(
type
{});
}
template
<
index_t
...
Xs
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
push_front
(
sequence
<
Xs
...
>
)
{
return
sequence
<
Xs
...,
Is
...
>
{};
}
template
<
index_t
...
Xs
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
push_front
(
number
<
Xs
>
...)
{
return
sequence
<
Xs
...,
Is
...
>
{};
}
template
<
index_t
...
Xs
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
push_back
(
sequence
<
Xs
...
>
)
{
return
sequence
<
Is
...,
Xs
...
>
{};
}
template
<
index_t
...
Xs
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
push_back
(
number
<
Xs
>
...)
{
return
sequence
<
Is
...,
Xs
...
>
{};
}
// pickup element at index <Ids...>
template
<
index_t
...
Ids
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
extract
(
number
<
Ids
>
...)
{
return
sequence
<
type
::
get
(
number
<
Ids
>
{})...
>
{};
}
template
<
index_t
...
Ids
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
extract
(
sequence
<
Ids
...
>
)
{
return
sequence
<
type
::
get
(
number
<
Ids
>
{})...
>
{};
}
// modify element at index "I" with value "X"
template
<
index_t
I
,
index_t
X
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
modify
(
number
<
I
>
,
number
<
X
>
)
{
static_assert
(
I
<
size
(),
"wrong!"
);
using
seq_split
=
sequence_split
<
type
,
I
>
;
constexpr
auto
seq_left
=
typename
seq_split
::
left_type
{};
constexpr
auto
seq_right
=
typename
seq_split
::
right_type
{}.
pop_front
();
return
seq_left
.
push_back
(
number
<
X
>
{}).
push_back
(
seq_right
);
}
template
<
typename
F
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
transform
(
F
f
)
{
return
sequence
<
f
(
Is
)...
>
{};
}
CK_TILE_HOST_DEVICE
static
void
print
()
{
printf
(
"sequence{size: %d, data: ["
,
size
());
((
printf
(
"%d "
,
Is
)),
...);
printf
(
"]}"
);
}
};
namespace
impl
{
template
<
typename
T
,
T
...
Ints
>
struct
__integer_sequence
;
template
<
index_t
...
Ints
>
struct
__integer_sequence
<
index_t
,
Ints
...
>
{
using
seq_type
=
sequence
<
Ints
...
>
;
};
}
// namespace impl
// similar
template
<
index_t
N
>
using
make_index_sequence
=
typename
__make_integer_seq
<
impl
::
__integer_sequence
,
index_t
,
N
>::
seq_type
;
// merge sequence
template
<
typename
Seq
,
typename
...
Seqs
>
struct
sequence_merge
{
using
type
=
typename
sequence_merge
<
Seq
,
typename
sequence_merge
<
Seqs
...
>::
type
>::
type
;
};
template
<
index_t
...
Xs
,
index_t
...
Ys
>
struct
sequence_merge
<
sequence
<
Xs
...
>
,
sequence
<
Ys
...
>>
{
using
type
=
sequence
<
Xs
...,
Ys
...
>
;
};
template
<
typename
Seq
>
struct
sequence_merge
<
Seq
>
{
using
type
=
Seq
;
};
// generate sequence
template
<
index_t
NSize
,
typename
F
>
struct
sequence_gen
{
template
<
index_t
IBegin
,
index_t
NRemain
,
typename
G
>
struct
sequence_gen_impl
{
static
constexpr
index_t
NRemainLeft
=
NRemain
/
2
;
static
constexpr
index_t
NRemainRight
=
NRemain
-
NRemainLeft
;
static
constexpr
index_t
IMiddle
=
IBegin
+
NRemainLeft
;
using
type
=
typename
sequence_merge
<
typename
sequence_gen_impl
<
IBegin
,
NRemainLeft
,
G
>::
type
,
typename
sequence_gen_impl
<
IMiddle
,
NRemainRight
,
G
>::
type
>::
type
;
};
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
1
,
G
>
{
static
constexpr
index_t
Is
=
G
{}(
number
<
I
>
{});
using
type
=
sequence
<
Is
>
;
};
template
<
index_t
I
,
typename
G
>
struct
sequence_gen_impl
<
I
,
0
,
G
>
{
using
type
=
sequence
<>
;
};
using
type
=
typename
sequence_gen_impl
<
0
,
NSize
,
F
>::
type
;
};
// arithmetic sequence
template
<
index_t
IBegin
,
index_t
IEnd
,
index_t
Increment
>
struct
arithmetic_sequence_gen
{
struct
F
{
CK_TILE_HOST_DEVICE
constexpr
index_t
operator
()(
index_t
i
)
const
{
return
i
*
Increment
+
IBegin
;
}
};
using
type0
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type1
=
sequence
<>
;
static
constexpr
bool
kHasContent
=
(
Increment
>
0
&&
IBegin
<
IEnd
)
||
(
Increment
<
0
&&
IBegin
>
IEnd
);
using
type
=
typename
std
::
conditional
<
kHasContent
,
type0
,
type1
>::
type
;
};
template
<
index_t
IEnd
>
struct
arithmetic_sequence_gen
<
0
,
IEnd
,
1
>
{
using
type
=
make_index_sequence
<
IEnd
>
;
};
// uniform sequence
template
<
index_t
NSize
,
index_t
I
>
struct
uniform_sequence_gen
{
struct
F
{
CK_TILE_HOST_DEVICE
constexpr
index_t
operator
()(
index_t
)
const
{
return
I
;
}
};
using
type
=
typename
sequence_gen
<
NSize
,
F
>::
type
;
};
// reverse inclusive scan (with init) sequence
template
<
typename
,
typename
,
index_t
>
struct
sequence_reverse_inclusive_scan
;
template
<
index_t
I
,
index_t
...
Is
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
sequence
<
I
,
Is
...
>
,
Reduce
,
Init
>
{
using
old_scan
=
typename
sequence_reverse_inclusive_scan
<
sequence
<
Is
...
>
,
Reduce
,
Init
>::
type
;
static
constexpr
index_t
new_reduce
=
Reduce
{}(
I
,
old_scan
{}.
front
());
using
type
=
typename
sequence_merge
<
sequence
<
new_reduce
>
,
old_scan
>::
type
;
};
template
<
index_t
I
,
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
sequence
<
I
>
,
Reduce
,
Init
>
{
using
type
=
sequence
<
Reduce
{}(
I
,
Init
)
>
;
};
template
<
typename
Reduce
,
index_t
Init
>
struct
sequence_reverse_inclusive_scan
<
sequence
<>
,
Reduce
,
Init
>
{
using
type
=
sequence
<>
;
};
// split sequence
template
<
typename
Seq
,
index_t
I
>
struct
sequence_split
{
static
constexpr
index_t
NSize
=
Seq
{}.
size
();
using
range0
=
typename
arithmetic_sequence_gen
<
0
,
I
,
1
>::
type
;
using
range1
=
typename
arithmetic_sequence_gen
<
I
,
NSize
,
1
>::
type
;
using
left_type
=
decltype
(
Seq
::
extract
(
range0
{}));
using
right_type
=
decltype
(
Seq
::
extract
(
range1
{}));
};
#if 0
// reverse sequence
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
struct sequence_reverse<sequence<I>>
{
using type = sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<sequence<I0, I1>>
{
using type = sequence<I1, I0>;
};
#endif
namespace
impl
{
template
<
typename
Id
,
index_t
...
Ns
>
struct
seq_reverse
;
template
<
index_t
...
Ids
,
index_t
...
Ns
>
struct
seq_reverse
<
sequence
<
Ids
...
>
,
Ns
...
>
{
template
<
index_t
I
>
using
element
=
impl
::
at_index_t
<
I
,
constant
<
Ns
>
...
>
;
using
type
=
sequence
<
element
<
(
sizeof
...(
Ns
)
-
1
-
Ids
)
>::
value
...
>
;
};
}
// namespace impl
template
<
index_t
...
Ns
>
struct
sequence_reverse
<
sequence
<
Ns
...
>>
:
impl
::
seq_reverse
<
make_index_sequence
<
sizeof
...(
Ns
)
>
,
Ns
...
>
{
};
// template <index_t... Ns>
// using sequence_reverse_t = typename sequence_reverse<Ns...>::type;
#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
    using type = typename sequence_reduce<Reduce,
                                          Seq,
                                          typename sequence_reduce<Reduce, Seqs...>::type>::type;
};

template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, sequence<Xs...>, sequence<Ys...>>
{
    using type = sequence<Reduce{}(Xs, Ys)...>;
};

template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
    using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl
    {
        static constexpr bool choose_left = LeftValues::front() < RightValues::front();

        static constexpr index_t chosen_value =
            choose_left ? LeftValues::front() : RightValues::front();
        static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();

        using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
        using new_merged_ids    = decltype(MergedIds::push_back(number<chosen_id>{}));

        using new_left_values =
            typename std::conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
        using new_left_ids =
            typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;

        using new_right_values =
            typename std::conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
        using new_right_ids =
            typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;

        using merge = sorted_sequence_merge_impl<new_left_values,
                                                 new_left_ids,
                                                 new_right_values,
                                                 new_right_ids,
                                                 new_merged_values,
                                                 new_merged_ids,
                                                 Comp>;

        // this is output
        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<LeftValues,
                                      LeftIds,
                                      sequence<>,
                                      sequence<>,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, LeftIds>::type;
    };

    template <typename RightValues,
              typename RightIds,
              typename MergedValues,
              typename MergedIds,
              typename Comp>
    struct sorted_sequence_merge_impl<sequence<>,
                                      sequence<>,
                                      RightValues,
                                      RightIds,
                                      MergedValues,
                                      MergedIds,
                                      Comp>
    {
        using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
        using merged_ids    = typename sequence_merge<MergedIds, RightIds>::type;
    };

    template <typename LeftValues,
              typename LeftIds,
              typename RightValues,
              typename RightIds,
              typename Comp>
    struct sorted_sequence_merge
    {
        using merge = sorted_sequence_merge_impl<LeftValues,
                                                 LeftIds,
                                                 RightValues,
                                                 RightIds,
                                                 sequence<>,
                                                 sequence<>,
                                                 Comp>;

        using merged_values = typename merge::merged_values;
        using merged_ids    = typename merge::merged_ids;
    };

    static constexpr index_t nsize = Values::size();

    using split_unsorted_values = sequence_split<Values, nsize / 2>;
    using split_unsorted_ids    = sequence_split<Ids, nsize / 2>;

    using left_unsorted_values = typename split_unsorted_values::left_type;
    using left_unsorted_ids    = typename split_unsorted_ids::left_type;

    using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
    using left_sorted_values = typename left_sort::sorted_values;
    using left_sorted_ids    = typename left_sort::sorted_ids;

    using right_unsorted_values = typename split_unsorted_values::right_type;
    using right_unsorted_ids    = typename split_unsorted_ids::right_type;

    using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
    using right_sorted_values = typename right_sort::sorted_values;
    using right_sorted_ids    = typename right_sort::sorted_ids;

    using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                                left_sorted_ids,
                                                right_sorted_values,
                                                right_sorted_ids,
                                                Compare>;

    using sorted_values = typename merged_sorted::merged_values;
    using sorted_ids    = typename merged_sorted::merged_ids;
};

template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
{
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values = typename std::
        conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
    using sorted_ids =
        typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
};

template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
{
    using sorted_values = sequence<Value>;
    using sorted_ids    = sequence<Id>;
};

template <typename Compare>
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
{
    using sorted_values = sequence<>;
    using sorted_ids    = sequence<>;
};

template <typename Values, typename Compare>
struct sequence_sort
{
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
    using sort         = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type                = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
};
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
    struct sorted_sequence_uniquify_impl
    {
        static constexpr index_t current_value = RemainValues::front();
        static constexpr index_t current_id    = RemainIds::front();

        static constexpr bool is_unique_value = (current_value != UniquifiedValues::back());

        using new_remain_values = decltype(RemainValues::pop_front());
        using new_remain_ids    = decltype(RemainIds::pop_front());

        using new_uniquified_values =
            typename std::conditional<is_unique_value,
                                      decltype(UniquifiedValues::push_back(number<current_value>{})),
                                      UniquifiedValues>::type;
        using new_uniquified_ids =
            typename std::conditional<is_unique_value,
                                      decltype(UniquifiedIds::push_back(number<current_id>{})),
                                      UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
    };

    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<sequence<>,
                                         sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
    {
        using uniquified_values = UniquifiedValues;
        using uniquified_ids    = UniquifiedIds;
    };

    template <typename SortedValues, typename SortedIds, typename Eq>
    struct sorted_sequence_uniquify
    {
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::pop_front()),
                                                       decltype(SortedIds::pop_front()),
                                                       sequence<SortedValues::front()>,
                                                       sequence<SortedIds::front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids    = typename uniquify::uniquified_ids;
    };

    using sort          = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids    = typename sort::sorted2unsorted_map;

    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type                = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
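// [Illustrative sketch added in this listing, not part of the original header]
// sequence_sort is a compile-time merge sort that also reports where each sorted element
// came from; sequence_unique_sort additionally drops adjacent duplicates after sorting.
#if 0
using sorted = sequence_sort<sequence<3, 1, 2>, less<index_t>>;
static_assert(std::is_same_v<typename sorted::type, sequence<1, 2, 3>>);
static_assert(std::is_same_v<typename sorted::sorted2unsorted_map, sequence<1, 2, 0>>);
#endif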
template <typename SeqMap>
struct is_valid_sequence_map
    : std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
                   typename sequence_sort<SeqMap, less<index_t>>::type>
{
};
template <typename SeqMap>
struct sequence_map_inverse
{
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::modify(X2Y::get(number<XBegin>{}), number<XBegin>{});

        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
    };

    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };

    using type =
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::size(), 0>::type,
                                           0,
                                           SeqMap::size()>::type;
};
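// [Illustrative sketch added in this listing, not part of the original header]
// sequence_map_inverse turns a permutation x->y into its inverse y->x, e.g. the map
// <1, 2, 0> (x0->y1, x1->y2, x2->y0) inverts to <2, 0, 1>.
#if 0
static_assert(
    std::is_same_v<typename sequence_map_inverse<sequence<1, 2, 0>>::type, sequence<2, 0, 1>>);
#endif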
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator==(sequence<Xs...>, sequence<Ys...>)
{
    return ((Xs == Ys) && ...);
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator!=(sequence<Xs...> x, sequence<Ys...> y)
{
    return !(x == y);
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return sequence<(Xs + Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return sequence<(Xs - Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return sequence<(Xs * Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return sequence<(Xs / Ys)...>{};
}

template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");

    return sequence<(Xs % Ys)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs + Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs - Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs * Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs / Y)...>{};
}

template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, number<Y>)
{
    return sequence<(Xs % Y)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator+(number<Y>, sequence<Xs...>)
{
    return sequence<(Y + Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator-(number<Y>, sequence<Xs...>)
{
    return sequence<(Y - Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator*(number<Y>, sequence<Xs...>)
{
    return sequence<(Y * Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator/(number<Y>, sequence<Xs...>)
{
    return sequence<(Y / Xs)...>{};
}

template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator%(number<Y>, sequence<Xs...>)
{
    return sequence<(Y % Xs)...>{};
}
template <index_t I, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>)
{
    return sequence<Is...>{};
}

template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::size() > 0, "wrong! cannot pop an empty sequence!");
    return sequence_pop_front(Seq::reverse()).reverse();
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
{
    return typename sequence_merge<Seqs...>::type{};
}

template <typename F, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>)
{
    return sequence<f(Xs)...>{};
}
template <typename F, index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>, sequence<Ys...>)
{
    static_assert(sequence<Xs...>::size() == sequence<Ys...>::size(), "Dim not the same");

    return sequence<f(Xs, Ys)...>{};
}

template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
CK_TILE_HOST_DEVICE constexpr auto
transform_sequences(F f, sequence<Xs...>, sequence<Ys...>, sequence<Zs...>)
{
    static_assert(sequence<Xs...>::size() == sequence<Ys...>::size() &&
                      sequence<Xs...>::size() == sequence<Zs...>::size(),
                  "Dim not the same");

    return sequence<f(Xs, Ys, Zs)...>{};
}
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq::pop_front(), Reduce{}, number<Init>{})
        .push_back(number<Init>{});
}

template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number<Init>{}).reverse();
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add
//                       ResultSeq   TargetSeq   Reduce
template <typename, typename, typename>
struct sequence_exclusive_scan;

template <index_t... Xs, index_t Y, index_t... Ys, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y, Ys...>, Reduce>
{
    using old_scan =
        typename sequence_merge<sequence<Xs...>,
                                sequence<Reduce{}(Y, sequence<Xs...>{}.back())>>::type;
    using type = typename sequence_exclusive_scan<old_scan, sequence<Ys...>, Reduce>::type;
};

template <index_t... Xs, index_t Y, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y>, Reduce>
{
    using type = sequence<Xs...>;
};

template <index_t... Xs, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<>, Reduce>
{
    using type = sequence<Xs...>;
};

template <typename Seq, typename Reduce, index_t Init>
constexpr auto exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
    // TODO: c++20 and later can pass in Reduce with a lambda expression
    return typename sequence_exclusive_scan<sequence<Init>, Seq, Reduce>::type{};
}

template <typename Seq>
constexpr auto prefix_sum_sequence(Seq)
{
    return typename sequence_exclusive_scan<sequence<0>,
                                            typename sequence_merge<Seq, sequence<0>>::type,
                                            plus<index_t>>::type{};
}
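// [Illustrative sketch added in this listing, not part of the original header]
// prefix_sum_sequence returns the exclusive prefix sums with the total appended as the
// final element, e.g. <2, 3, 4> -> <0, 2, 5, 9>.
#if 0
static_assert(std::is_same_v<decltype(prefix_sum_sequence(sequence<2, 3, 4>{})),
                             sequence<0, 2, 5, 9>>);
#endif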
template <typename Seq, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_ids(Seq, sequence<Is...> /* ids */)
{
    return sequence<Seq::get(number<Is>{})...>{};
}

#if 1
namespace detail {

template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
    using new_work_seq = typename std::conditional<RemainMask::front(),
                                                   decltype(WorkSeq::push_back(RemainSeq::front())),
                                                   WorkSeq>::type;

    using type =
        typename pick_sequence_elements_by_mask_impl<new_work_seq,
                                                     decltype(RemainSeq::pop_front()),
                                                     decltype(RemainMask::pop_front())>::type;
};

template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, sequence<>, sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

template <typename Seq, typename Mask>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
    static_assert(Seq::size() == Mask::size(), "wrong!");

    return typename detail::pick_sequence_elements_by_mask_impl<sequence<>, Seq, Mask>::type{};
}

namespace detail {

template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
    using new_work_seq = decltype(WorkSeq::modify(RemainIds::front(), RemainValues::front()));

    using type =
        typename modify_sequence_elements_by_ids_impl<new_work_seq,
                                                      decltype(RemainValues::pop_front()),
                                                      decltype(RemainIds::pop_front())>::type;
};

template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, sequence<>, sequence<>>
{
    using type = WorkSeq;
};

} // namespace detail

template <typename Seq, typename Values, typename Ids>
CK_TILE_HOST_DEVICE constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
    static_assert(Values::size() == Ids::size() && Seq::size() >= Values::size(), "wrong!");

    return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
{
    index_t result = Init;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        result = f(result, Seq::at(i));
    }

    return result;
}

// TODO: a generic any_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
{
    bool flag = false;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        flag = flag || f(Seq::at(i));
    }

    return flag;
}

// TODO: a generic all_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
{
    bool flag = true;

    for(index_t i = 0; i < Seq::size(); ++i)
    {
        flag = flag && f(Seq::at(i));
    }

    return flag;
}
template <typename... Seqs>
using sequence_merge_t = typename sequence_merge<Seqs...>::type;

template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;

template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_sequence(number<Is>...)
{
    return sequence<Is...>{};
}

// F() returns index_t
// F use default constructor, so F cannot be lambda function
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number<N>)
{
    return typename sequence_gen<N, F>::type{};
}

// F() returns number<>
// F could be lambda function
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
{
    return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

template <class... T>
struct tuple;

template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
    return sequence<Is...>{};
}
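// [Illustrative sketch added in this listing, not part of the original header]
// generate_sequence_v2 accepts a lambda returning number<>, unlike generate_sequence
// whose functor must be default-constructible.
#if 0
constexpr auto evens = generate_sequence_v2([](auto i) { return number<2 * i.value>{}; }, number<3>{});
static_assert(std::is_same_v<remove_cvref_t<decltype(evens)>, sequence<0, 2, 4>>);
#endif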
namespace detail {

template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;

template <index_t h_idx, index_t x, index_t... xs, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x, xs...>, sequence<r, rs...>>
{
    template <typename Histogram>
    constexpr auto operator()(Histogram& h)
    {
        if constexpr(x < r)
        {
            h.template at<h_idx>() += 1;
            sorted_sequence_histogram<h_idx, sequence<xs...>, sequence<r, rs...>>{}(h);
        }
        else
        {
            h.template at<h_idx + 1>() = 1;
            sorted_sequence_histogram<h_idx + 1, sequence<xs...>, sequence<rs...>>{}(h);
        }
    }
};

template <index_t h_idx, index_t x, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
{
    template <typename Histogram>
    constexpr auto operator()(Histogram& h)
    {
        if constexpr(x < r)
        {
            h.template at<h_idx>() += 1;
        }
    }
};

} // namespace detail

template <typename, index_t>
struct array;
// declare for later use (array->seq utility)

// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
template <typename SeqSortedSamples, index_t r, index_t... rs>
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence<r, rs...>)
{
    constexpr auto bins = sizeof...(rs); // or categories
    constexpr auto histogram = [&]() {
        array<index_t, bins> h{0}; // make sure this can clear all element to zero
        detail::sorted_sequence_histogram<0, SeqSortedSamples, sequence<rs...>>{}(h);
        return h;
    }();

    return TO_SEQUENCE(histogram, bins);
}
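// [Illustrative sketch added in this listing, not part of the original header]
// Following the example in the comment above: samples <0, 2, 3, 5, 7> against bin
// boundaries <0, 3, 6, 9> fall 2 / 2 / 1 into the three bins.
#if 0
static_assert(std::is_same_v<
    decltype(histogram_sorted_sequence(sequence<0, 2, 3, 5, 7>{}, sequence<0, 3, 6, 9>{})),
    sequence<2, 2, 1>>);
#endif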
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
{
    using T = remove_cvref_t<decltype(f(number<0>{}))>;

    return unpack([&f](auto&&... is) { return array<T, N>{f(is)...}; },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

} // namespace ck_tile
include/ck_tile/core/container/span.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {

// implement the c++20 std::span: a lightweight, non-owning reference to a sequence,
// whether it is a dynamic or static range; can also be seen as a view of a contiguous sequence
// TODO: do we need this in device code, considering it holds a pointer?
template <typename T>
class span
{
    public:
    using element_type    = T;
    using value_type      = std::remove_cv_t<element_type>;
    using size_type       = std::size_t;
    using difference_type = std::ptrdiff_t;
    using pointer         = element_type*;
    using const_pointer   = const element_type*;
    using reference       = element_type&;
    using const_reference = const element_type&;
    using iterator        = pointer;
    using const_iterator  = pointer;

    CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}

    CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
    {
    }

    CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
    {
    }

    template <std::size_t N>
    CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
        : span(arr.data(), N)
    {
    }

    template <typename Container>
    CK_TILE_HOST_DEVICE constexpr span(const Container& container)
        : span(container.data(), container.size())
    {
    }

    CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
    CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }

    CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
    CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }

    CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
    CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); }

    CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
    {
        return *(begin() + idx);
    }

    CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }

    CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }

    private:
    pointer ptr_;
    size_type size_;
};

} // namespace ck_tile
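// [Illustrative sketch added in this listing, not part of the original header]
// ck_tile::span mirrors the std::span interface: a pointer plus a size, viewing a
// contiguous buffer it does not own.
#if 0
// sum a view over an existing buffer without copying it
inline int sum_all(ck_tile::span<const int> s)
{
    int total = 0;
    for(int v : s)
        total += v;
    return total;
}
// std::array<int, 4> buf{1, 2, 3, 4};
// int total = sum_all({buf.data(), buf.size()}); // 10
#endif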
include/ck_tile/core/container/statically_indexed_array.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {

#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
template <typename T, index_t N>
using statically_indexed_array = tuple_array<T, N>;
#else
// consider marking this alias as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif

// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
    return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}

// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
    return statically_indexed_array<X, 0>();
}
#endif

} // namespace ck_tile
include/ck_tile/core/container/thread_buffer.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {

#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
    return make_tuple(ts...);
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_array(ts...);
}
#endif
// clang-format off
template <typename T_, index_t N_>
struct thread_buffer
{
    using value_type = remove_cvref_t<T_>;
    static constexpr index_t N = N_;

    value_type data[N];

    // TODO: this ctor can't ignore
    CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
    CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type& o) : data{o} {}

    CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }

    CK_TILE_HOST_DEVICE auto& get() { return data; }
    CK_TILE_HOST_DEVICE const auto& get() const { return data; }
    CK_TILE_HOST_DEVICE auto& get(index_t i) { return data[i]; }
    CK_TILE_HOST_DEVICE const auto& get(index_t i) const { return data[i]; }

    CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); }

    // TODO: compatible
    CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
    CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }

    template <typename X_,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr auto _get_as() const
    {
        using X = remove_cvref_t<X_>;
        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        static_assert(N % kSPerX == 0);
        union {
            thread_buffer<X_, N / kSPerX> data{};
            // tuple_array<value_type, kSPerX> sub_data;
            value_type sub_data[N];
        } vx;
        static_for<0, N, 1>{}([&](auto j) { vx.sub_data[j] = data[j]; });
        return vx.data;
    }

    template <typename X_,
              index_t Is,
              typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
    CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
    {
        using X = remove_cvref_t<X_>;
        constexpr index_t kSPerX = vector_traits<X>::vector_size;
        union {
            X_ data{};
            tuple_array<value_type, kSPerX> sub_data;
        } vx;
        static_for<0, kSPerX, 1>{}([&](auto j) {
            vx.sub_data(j) = operator[]((is * number<sizeof(X_) / sizeof(value_type)>{}) + j);
        });
        return vx.data;
    }
#if 0
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data;
tuple_array<value_type, kSPerX> sub_data;
} vx {x};
static_for<0, kSPerX, 1>{}(
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
}
#endif
#define TB_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
    template <typename Tx>
    CK_TILE_HOST_DEVICE auto& get_as()
    {
        TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data);
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr auto get_as() const
    {
        TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1)
            return _get_as<Tx>();
        // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE auto& get_as(number<I>)
    {
        TB_COMMON_AS();
        return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const
    {
        TB_COMMON_AS();
        if constexpr(sizeof(value_type) <= 1)
            return _get_as<Tx>(number<I>{});
        // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
        else
            return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});
    }

    template <typename Tx>
    CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx& x)
    {
        TB_COMMON_AS();
        reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x;
    }

    template <typename Tx, index_t I>
    CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx& x)
    {
        TB_COMMON_AS();
        reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x;
    }

#undef TB_COMMON_AS
};
// clang-format on
template <typename>
struct vector_traits;

// specialization for array
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
    using scalar_type = T;
    static constexpr index_t vector_size = N;
};

#endif

} // namespace ck_tile
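// [Illustrative sketch added in this listing, not part of the original header]
// thread_buffer is a per-thread, register-resident array; get_as/set_as reinterpret the
// same storage at a wider vector granularity, which is how vectorized loads/stores use it.
// The vector type name below (fp32x4_t) is a hypothetical stand-in for a 4-float vector type.
#if 0
ck_tile::thread_buffer<float, 8> buf{};
buf.at(0) = 1.f;
// view the same 8 floats as 4-wide chunks
// auto v = buf.get_as<fp32x4_t>(ck_tile::number<0>{});
#endif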
include/ck_tile/core/container/tuple.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif
namespace ck_tile {

namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}

template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;

namespace impl {

// the place where content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_object
{
};

template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
    {
    }
#endif
};

template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
    CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
#if CK_TILE_TUPLE_IMPL == 0
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
    {
    }
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
    {
    }
#endif

    T element;
};
// NOTE: we return an instance (not a reference) if the content is empty
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
    return {};
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
    return x.element;
}

template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
    return static_cast<T&&>(x.element);
}
template <typename index_seq, typename... T>
struct tuple_base;

template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
    CK_TILE_HOST_DEVICE constexpr tuple_base() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (std::initializer_list<U>{}.size() - 1)
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
        : tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
    {
    }
#undef _ILE
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
        : tuple_object<I, T>(std::forward<U>(u))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
        : tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <class U,
              typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
                                          !std::is_same<remove_cvref_t<U>, tuple_base>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
    {
    }

    template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
    {
        static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
                      "wrong! inconsistent size");
    }
#endif
};

} // namespace impl
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
    CK_TILE_HOST_DEVICE
    static constexpr auto size() { return sizeof...(T); }

    using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;

    CK_TILE_HOST_DEVICE constexpr tuple() = default;

#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
    template <typename U>
    CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
    {
    }
#endif

#if CK_TILE_TUPLE_IMPL == 0
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
        : base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }
    template <class... U>
    CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
        : base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
    {
    }
#elif CK_TILE_TUPLE_IMPL == 1
    template <typename U,
              typename std::enable_if<sizeof...(T) == 1 &&
                                          !std::is_same<remove_cvref_t<U>, tuple>::value,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
    {
    }

    template <typename... U,
              typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
                                      bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
    {
    }
#endif
    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool flag = true;

        static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
            flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
        });

        return flag;
    }
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
    // clang-format off
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }

    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }

    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
    template <index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); }
    // TODO: compatible
    // the functions below should be used under the tuple_array<> type; no extra check is performed here
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
    template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }

    // below index is for index *AFTER* type convert, not before
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
    //template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }

    // template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
    template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx& x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
    // clang-format on
#undef TP_COM_
};
template <typename>
struct vector_traits;

// specialization for array
template <typename... T>
struct vector_traits<tuple<T...>>
{
    using scalar_type = __type_pack_element<0, T...>;
    static constexpr index_t vector_size = sizeof...(T);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    bool same = true;

    static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
        if(a[i] != b[i])
        {
            same = false;
        }
    });

    return same;
}

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
    return !(a == b);
}

template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
    // here xs is always an lvalue as a function arg
    // Xs may be deduced as (e.g. try passing in an integer in the following cases)
    // 1). if passed an rvalue (like a function return or int{}) -> Xs is "int"
    // 2). if passed a const lvalue -> Xs is "const int &"
    // 3). if passed a non-const lvalue -> Xs is "int &"
    // so the return type of std::forward will depend on Xs
    // 1). std::forward -> int&&
    // 2). std::forward -> const int&
    // 3). std::forward -> int&
    return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}

// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
    return {args...};
}
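// [Illustrative sketch added in this listing, not part of the original header]
// make_tuple decays its arguments into owned values, while tie keeps lvalue references,
// matching the behavior of the std:: counterparts.
#if 0
int m    = 0;
auto t   = make_tuple(1, 2.5f);  // tuple<int, float>
auto ref = tie(m);               // tuple<int&>, refers to m
ref.at(number<0>{}) = 7;         // writes through to m
#endif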
template <typename X, typename Y>
struct tuple_concat;

template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
    using type = tuple<Xs..., Ys...>;
};

namespace impl {
// be very careful using this type (because we want the internal type)
// template deduction will fail when inferring the inner type
// e.g.
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> compiler will fail to deduce this type, because this is a non-deduced context
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
// contexts")
//
// -> use this instead
// template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
    using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
                                       typename tuple_array_impl<T, N - N / 2>::type>::type;
};

template <typename T>
struct tuple_array_impl<T, 0>
{
    using type = tuple<>;
};

template <typename T>
struct tuple_array_impl<T, 1>
{
    using type = tuple<T>;
};

} // namespace impl
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
    return unpack([&f](auto&&... is) { return tie(f(is)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
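// [Illustrative sketch added in this listing, not part of the original header]
// generate_tuple calls f with number<0> ... number<N-1> and collects the (possibly
// heterogeneous) results; generate_tie does the same but keeps lvalue references.
#if 0
auto squares = generate_tuple([](auto i) { return number<i.value * i.value>{}; }, number<4>{});
// squares is tuple<number<0>, number<1>, number<4>, number<9>>
#endif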
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
                                                             const tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
    return unpack2(
        [&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}

// Support any number of tuples to concat (also 1)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
    return tx;
}

template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {

template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}

template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
    return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}

} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
    return detail::transform_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    return detail::transform_tuples_impl(
        f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    return detail::transform_tuples_impl(
        f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the fully flattened form
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
    return t;
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
    return make_tuple(t);
}

template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
{
    if constexpr(Depth == MaxDepth)
    {
        return t;
    }
    else
    {
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
            },
            t);
    }
}
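// [Illustrative sketch added in this listing, not part of the original header]
// unroll_nested_tuple flattens nested tuples depth-first; MaxDepth can stop the
// flattening early (the default -1 means "flatten completely").
#if 0
auto nested = make_tuple(make_tuple(1, 2), 3);
auto flat   = unroll_nested_tuple(nested); // tuple<int, int, int> holding 1, 2, 3
#endif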
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
    return generate_tuple(
        [&](auto i) {
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
{
    static_assert(Idx < End, "Wrong parameters for tuple_reduce");

    if constexpr(Idx + 1 == End)
    {
        return t.at(number<Idx>{});
    }
    else
    {
        return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
    }
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
    return (is_detected<is_tuple, Ts>::value || ...);
}

template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
    return depth;
}

template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
    return max(tuple_depth<depth + 1>(Ts{})...);
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
    constexpr index_t n0 = sizeof...(Seqs);

    constexpr index_t max_n1 = [&] {
        index_t max_n1_ = 0;

        static_for<0, n0, 1>{}([&](auto i0) {
            constexpr index_t n1 = t_of_s[i0].size();

            max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
        });

        return max_n1_;
    }();

    array<array<index_t, max_n1>, n0> a_of_a{{-1}};

    static_for<0, n0, 1>{}([&](auto i0) {
        constexpr index_t n1 = t_of_s[i0].size();

        static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
    });

    return a_of_a;
}
// Here should use MultiIndex<NSize>, instead of tuple<Ys...>, although the former
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);

    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
}

template <typename... Ys,
          typename X,
          std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);

    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
}
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
    return r;
}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
    return r;
}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
    return r;
}
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
          typename Y,
          std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
    static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
    return r;
}

// MultiIndex = MultiIndex * scalar
template <typename... Xs,
          typename Y,
          std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
    return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);

    return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}

} // namespace ck_tile
#include <tuple>
// WARNING: needed by compiler for C++ structured binding support only, don't use this
namespace std {

template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
{
};

template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};

template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>> : std::tuple_element<I, const std::tuple<Ts...>>
{
};

} // namespace std
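// [Illustrative sketch added in this listing, not part of the original header]
// The std::tuple_size / std::tuple_element specializations above exist so that structured
// bindings work on ck_tile::tuple:
#if 0
auto [x, y] = ck_tile::make_tuple(1, 2.0f); // x deduced as int, y as float
#endif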
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif
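A minimal host-side sketch of the element-wise tuple arithmetic defined in this header; the values and the ck_tile::make_tuple construction are illustrative assumptions, not part of the commit:

// hypothetical usage of the operators above
auto a      = ck_tile::make_tuple(2, 3);
auto b      = ck_tile::make_tuple(4, 5);
auto prod   = a * b; // element-wise operator*: {8, 15}
auto scaled = 2 * a; // scalar * tuple:          {4, 6}
auto quot   = a / b; // element-wise division via generate_tuple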
include/ck_tile/core/numeric/bfloat16.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#pragma once
namespace ck_tile {

enum class bf16_rounding_mode
{
    standard = 0, // rtn
    truncate_with_nan,
    truncate,
};

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});

CK_TILE_HOST_DEVICE constexpr float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE constexpr double bf16_to_double_raw(uint16_t x);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses __hip_bfloat16 as a struct
struct alignas(2) bfloat16_t
{
    using raw_type = uint16_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr bfloat16_t bit_cast(raw_type x)
    {
        bfloat16_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr bfloat16_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr bfloat16_t(const unsigned int& x)
        : data(float_to_bf16_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf16_to_float_raw(data); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return bf16_to_double_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

template <typename>
struct native_t;

template <>
struct native_t<bfloat16_t>
{
    using type = ushort;
};

using bf16_t     = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t     = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif

// round to nearest
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    if(~u.int32 & 0x7f800000)
    {
        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
    }
    else if(u.int32 & 0xffff)
    {
        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        u.int32 |= 0x10000; // Preserve signaling NaN
    }

    return uint16_t(u.int32 >> 16);
}

// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}

// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    return uint16_t(u.int32 >> 16);
}

template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
    if constexpr(rounding == bf16_rounding_mode::standard)
        return float_to_bf16_rtn_raw(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else
        return float_to_bf16_truc_raw(f);
}

template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
    return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};

    return u.fp32;
}

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x)
{
    return static_cast<double>(bf16_to_float_raw(x));
}

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }

CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }

template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
{
    return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}

CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }

template <class T>
struct numeric;

template <>
struct numeric<bfloat16_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080)); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f)); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f)); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000)); }

    // maximum rounding error
    // bin : f edcba 9876543210
    // bits: s eeeeeeee mmmmmmm
    //       0 01111110 0000000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00)); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80)); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF)); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF)); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001)); }

    CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero() { return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0)); }
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif

// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
    return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}

CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    uint16_t xx = bit_cast<bf16_raw_t>(x);
    return (xx & 0x7FFF) > 0x7C00;
}

CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x) { return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };

} // namespace ck_tile
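A small host-side sketch of how the round-to-nearest and truncate paths above differ; the sample value is an illustrative assumption, not from the commit:

// hypothetical round-trip through the bf16 helpers above
float f = 1.005859375f; // 0x3F80C000: lower 16 bits are 0xC000, above the rounding midpoint
uint16_t rtn = ck_tile::float_to_bf16_raw(f, ck_tile::constant<ck_tile::bf16_rounding_mode::standard>{}); // 0x3F81
uint16_t rtz = ck_tile::float_to_bf16_raw(f, ck_tile::constant<ck_tile::bf16_rounding_mode::truncate>{}); // 0x3F80
float back_rtn = ck_tile::bf16_to_float_raw(rtn); // 1.0078125 (rounded up to nearest)
float back_rtz = ck_tile::bf16_to_float_raw(rtz); // 1.0        (low 16 bits simply dropped)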
include/ck_tile/core/numeric/float8.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {

// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
enum class fp8_rounding_mode
{
    standard = 0,
    stochastic
};

/*
 *             ______________NANOO_________________ | ______________IEEE________________
 *             e4m3                e5m2              | e4m3               e5m2
 * bias      : 8                   16                | 7                  15
 * inf       : 1.0000.000          1.00000.00        | N/A                s.11111.00
 * Nan       : 1.0000.000          1.00000.00        | s.1111.111         s.11111.{01, 10, 11}
 * zero      : 0.0000.000          0.00000.00        | s.0000.000         s.00000.00
 * Max(norm) : s.1111.111 (240)    s.11111.11(57344) | s.1111.110(448)    s.11110.11(57344)
 * Max(snorm): s.0000.111          s.00000.11        | s.0000.111(448)    s.00000.11(57344)
 *             0.0068359375        2.288818e-05      | 0.013671875        4.57763671875e-05
 * Min(norm) : s.0001.000          s.00001.00        | s.0001.000         s.00001.00
 *             2^-7(0.00078125)    2^-15(3.05176e-05)| 2^-6(0.015625)     2^-14(6.10352e-05)
 * Min(snorm): s.0000.001          s.00000.01        | s.0000.001         s.00000.01
 *             2^-10(0.00097656)   2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
 */

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});

CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e4m3_t bit_cast(raw_type x)
    {
        float8_e4m3_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e4m3_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e4m3_t(const unsigned int& x)
        : data(float_to_fp8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

using fp8_t     = float8_e4m3_t;
using fp8_raw_t = typename fp8_t::raw_type;

struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
    using raw_type = uint8_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr float8_e5m2_t bit_cast(raw_type x)
    {
        float8_e5m2_t y;
        y.data = x;
        return y;
    }

    // constructor
    constexpr float8_e5m2_t() : data() {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr float8_e5m2_t(const unsigned int& x)
        : data(float_to_bf8_raw(static_cast<float>(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return bf8_to_float_raw(data); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

using bf8_t     = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;

template <typename>
struct native_t;

template <>
struct native_t<fp8_t>
{
    using type = _BitInt(8);
};

template <>
struct native_t<bf8_t>
{
    using type = unsigned _BitInt(8);
};
#else
using fp8_t     = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t     = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif

// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {

template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int out_exp  = numeric_traits<Y>::exp;
    constexpr int out_mant = numeric_traits<Y>::mant;

    // original type exponent/mantissa layout
    constexpr int in_exp  = numeric_traits<X>::exp;
    constexpr int in_mant = numeric_traits<X>::mant;

    int exponent, bias;
    uint32_t head, mantissa, sign;

    // nan code is same for float and half
#if CK_TILE_USE_CUSTOM_DATA_TYPE
    constexpr Y nan_code = numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
#else
    constexpr Y nan_code = 0x80;
#endif
    constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;

    // convert to bitwise
    using T_bitwise     = typename numeric_traits<X>::bitwise_type;
    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

    // unpack the input, depends on datatype
    head     = x_bitwise & numeric_traits<X>::head_mask;
    mantissa = x_bitwise & numeric_traits<X>::mant_mask;
    exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
    sign     = head >> (in_exp + in_mant);
    bias     = numeric_traits<X>::bias;

    uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
    uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);

    if constexpr(negative_zero_nan)
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return nan_code;
    }
    else
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    // check if x is 0.0
    if(x_bitwise == 0)
        return __builtin_bit_cast(Y, static_cast<uint8_t>(0));

    // First need to check if it is normal or denorm as there is a difference of implicit 1.
    // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
    // the mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
    // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
    // exponent and mantissa again.

    // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
    const int out_bias                  = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
    const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal

    // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
    // out_exponent is the converted f8 exponent with bias encoding
    // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
    // the difference needs to be adjusted and mantissa shifted
    int act_exponent, out_exponent, exponent_diff;

    if(exponent == 0)
    {
        // fp32/fp16 is in denormal.
        /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
           here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
           exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
           fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
           where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
           In this case, the fp16 mantissa should be shift left by 1 */
        act_exponent  = exponent - bias + 1;
        exponent_diff = out_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
    }
    else
    {
        // fp32/fp16 is normal with implicit 1
        act_exponent = exponent - bias;
        if(act_exponent <= out_denormal_act_exponent)
        {
            /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
               For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
               actual exponent is -7, it is actually larger due to the implicit 1,
               Therefore it needs to be adjusted to -6 and mantissa shifted right by 1.
               So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
            exponent_diff = out_denormal_act_exponent - act_exponent;
        }
        else
        {
            // both fp32/fp16 and f8 are in normal range
            exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
                               // act_exponent could be larger. Just that it does not need shift mantissa
        }
        mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
    }

    bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
                    (1 << (in_mant - out_mant + exponent_diff - 1));
    /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
       shift right as shift right could rip off some residual part and make something not midpoint look
       like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
       midpoint, but after shift right by 4 bits, it would look like midpoint. */

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << in_mant);

    // if there is no implicit 1, it means the f8 is denormal and need to adjust to denorm exponent
    out_exponent =
        (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);

    // Now we have the exponent and mantissa adjusted
    bool odd = mantissa & (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
    mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;

    // Now we deal with overflow
    if(out_exponent == 0)
    {
        if((1 << in_mant) & mantissa)
        {
            out_exponent = 1; // denormal overflow to become normal, promote exponent
            // No need to make 1 implicit now as it will be addressed later
        }
    }
    else
    {
        if((1 << (in_mant + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
            // No need to make 1 implicit now as it will be addressed later
        }
    }

    mantissa >>= (in_mant - out_mant);

    if(out_exponent > max_exp)
    {
        if(clip)
        {
            mantissa     = (1 << out_mant) - 1;
            out_exponent = max_exp;
        }
        else
        {
            return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
        }
    }

    // check if x is 0.0 or -0.0
    if(out_exponent == 0 && mantissa == 0)
        return __builtin_bit_cast(
            Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));

    mantissa &= (1 << out_mant) - 1;
    return __builtin_bit_cast(
        Y, static_cast<uint8_t>((sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa));
}

template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int in_exp  = numeric_traits<X>::exp;
    constexpr int in_mant = numeric_traits<X>::mant;

    // resulting type exponent/mantissa layout
    constexpr int out_exp  = numeric_traits<Y>::exp;
    constexpr int out_mant = numeric_traits<Y>::mant;

    uint8_t x_raw = __builtin_bit_cast(uint8_t, x);

    // prepare the codes
    constexpr uint8_t nan_code = 0x80;
    Y Inf, NegInf, NaN, Neg0;

    using T_bitwise = typename numeric_traits<Y>::bitwise_type;

    constexpr T_bitwise Inf_bitwise    = numeric_traits<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise    = numeric_traits<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise   = numeric_traits<Y>::Neg0;

    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));

    // check if x is 0.0
    if(x_raw == 0)
        return static_cast<Y>(0);

    // unpack the input
    uint32_t sign     = x_raw >> (in_exp + in_mant);
    uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
    int exponent      = (x_raw & 0x7F) >> in_mant;

    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    T_bitwise retval;

    if constexpr(negative_zero_nan)
    {
        if(x_raw == nan_code)
            return NaN;
    }
    else
    {
        if(x_raw == nan_code)
            return Neg0;
        if(exponent == ((1 << in_exp) - 1))
            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

    if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
    {
        retval = x_raw;
        retval <<= 8;
        return *(reinterpret_cast<const Y*>(&retval));
    }

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        int sh = 1 + clz(mantissa) - (32 - in_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
        mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
    return *(reinterpret_cast<const Y*>(&retval));
}

template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
    // check datatypes
    constexpr bool is_half  = std::is_same<X, half_t>::value;
    constexpr bool is_float = std::is_same<X, float>::value;
    static_assert(is_half || is_float, "Only half and float can be casted.");

    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}

template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
    // check datatype
    constexpr bool is_half  = std::is_same<Y, half_t>::value;
    constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "only half and float are supported.");

    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}

} // namespace impl

CK_TILE_HOST_DEVICE
fp8_raw_t float_to_fp8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<fp8_raw_t>(
        impl::cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE
bf8_raw_t float_to_bf8_sr_raw(float x)
{
    constexpr int seed = 42;
    uint32_t rng       = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val    = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::stochastic;
    return bit_cast<bf8_raw_t>(
        impl::cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE
fp8_raw_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float max_fp8 = 240.0f;
    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val    = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<fp8_raw_t>(
        impl::cast_to_f8<float, fp8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

CK_TILE_HOST_DEVICE
bf8_raw_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval      = x;
    uint32_t ival = 0;
    ival          = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val    = ival;
    return val.i8val[0];
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr fp8_rounding_mode rm   = fp8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
    return bit_cast<bf8_raw_t>(
        impl::cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}

// clang-format off
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
        return float_to_fp8_rtn_raw(x);
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
        return float_to_fp8_sr_raw(x);
    else
        return fp8_raw_t{0};
}

template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
    if constexpr(rounding == fp8_rounding_mode::standard)
        return float_to_bf8_rtn_raw(x);
    else if constexpr(rounding == fp8_rounding_mode::stochastic)
        return float_to_bf8_sr_raw(x);
    else
        return bf8_raw_t{0};
}

CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
    // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif
}

CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    float fval;
    uint32_t i32val = static_cast<uint32_t>(x);
    fval            = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
    // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
    return fval;
#else
    constexpr bool negative_zero_nan = true;
    return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
    return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}

template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
    return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}

CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }

CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
// clang-format on

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<fp8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 8;
#else
    static constexpr int bias = 7;
#endif
};

template <>
struct numeric_traits<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    static constexpr int bias = 16;
#else
    static constexpr int bias = 15; // IEEE
#endif
};

template <class T>
struct numeric;

template <>
struct numeric<fp8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff)); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f)); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); }

    // maximum rounding error
    // bin : 7 6543 210
    // bits: s eeee mmm
    //       0 0110 000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30)); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80)); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80)); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80)); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01)); }

    CK_TILE_HOST_DEVICE static constexpr fp8_t zero() { return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0)); }
};

template <>
struct numeric<bf8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff)); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f)); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); }

    // maximum rounding error
    // bin : 7 65432 10
    // bits: s eeeee mm
    //       0 01110 00 (0.5)
    CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38)); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80)); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80)); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80)); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01)); }

    CK_TILE_HOST_DEVICE static constexpr bf8_t zero() { return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0)); }
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif

// math
CK_TILE_HOST_DEVICE fp8_t abs(const fp8_t& x)
{
    return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };

CK_TILE_HOST_DEVICE bf8_t abs(const bf8_t& x)
{
    return bit_cast<bf8_t>(static_cast<fp8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}

CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}

CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };

} // namespace ck_tile
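A host-side sketch contrasting the two fp8 rounding paths declared above; the input value is an illustrative assumption. On non-gfx94x targets both calls go through the software impl::cast_to_f8 path, on gfx940/941/942 they map to the hardware cvt builtins:

// hypothetical usage of the fp8 conversion helpers above
float x        = 0.3f;
auto  y_rne    = ck_tile::float_to_fp8(x, ck_tile::constant<ck_tile::fp8_rounding_mode::standard>{});
auto  y_sr     = ck_tile::float_to_fp8(x, ck_tile::constant<ck_tile::fp8_rounding_mode::stochastic>{});
float back_rne = ck_tile::fp8_to_float(y_rne); // deterministic round-to-nearest-even result
float back_sr  = ck_tile::fp8_to_float(y_sr);  // may land on either neighbouring fp8 value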
include/ck_tile/core/numeric/half.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {

using fp16_hip_t = _Float16; // most HIP internal functions use this type
using fp16_raw_t = uint16_t;

CK_TILE_HOST_DEVICE constexpr float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE constexpr double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE constexpr fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE constexpr fp16_hip_t double_to_fp16_hip(const double& x);

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses fp16_hip_t as the interchangeable data type for float16
struct alignas(2) half_t
{
    using raw_type = fp16_raw_t;
    raw_type data;

    CK_TILE_HOST_DEVICE
    static constexpr half_t bit_cast(raw_type x)
    {
        half_t y;
        y.data = x;
        return y;
    }

    CK_TILE_HOST_DEVICE
    constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // constructor
    constexpr half_t() : data{} {}

    // construct from HIP half
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}

    // construct from float
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}

    // construct from double
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}

    // construct from int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}

    // construct from unsigned int
    CK_TILE_HOST_DEVICE
    explicit constexpr half_t(const unsigned int& x)
        : half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
    {
    }

    // cast to float
    CK_TILE_HOST_DEVICE
    explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }

    // cast to double
    CK_TILE_HOST_DEVICE
    explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }

    // cast to int
    CK_TILE_HOST_DEVICE
    explicit constexpr operator int() const { return static_cast<int>(fp16_to_float_hip(to_fp16())); }

    CK_TILE_HOST_DEVICE
    explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }

    // internal access
    CK_TILE_HOST_DEVICE
    constexpr raw_type& get() { return data; }

    CK_TILE_HOST_DEVICE
    constexpr raw_type get() const { return data; }
};

template <typename>
struct native_t;

template <>
struct native_t<half_t>
{
    using type = _Float16;
};

using fp16_t     = half_t;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t     = _Float16;
using half_t     = _Float16;
using fp16_raw_t = ushort;
#endif

// conversions
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
    // return __half2float(x);
    return static_cast<float>(x);
}

CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
    return static_cast<double>(fp16_to_float_hip(x));
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
    return __float2half(x);
    // return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
    // return __float2half(x);
    return static_cast<fp16_hip_t>(x);
}

CK_TILE_HOST_DEVICE
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
constexpr float fp16_to_double(const half_t& x) { return static_cast<float>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }

CK_TILE_HOST_DEVICE
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }

// limits
template <class T>
struct numeric;

template <>
struct numeric<half_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr half_t min() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400)); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF)); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr half_t max() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF)); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800)); }

    // maximum rounding error
    // bin : f edcba 9876543210
    // bits: s eeeee mmmmmmmmmm
    //       0 01110 0000000000 (0.5)
    CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800)); }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00)); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF)); }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF)); }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001)); }

    CK_TILE_HOST_DEVICE static constexpr half_t zero() { return bit_cast<half_t>(static_cast<fp16_raw_t>(0)); }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<half_t>
{
    static constexpr int exp            = 5;
    static constexpr int mant           = 10;
    static constexpr int bias           = 15;
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;
    using bitwise_type                  = uint16_t;
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }

CK_TILE_DEVICE bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif

// math
CK_TILE_HOST_DEVICE half_t abs(const half_t& x) { return bit_cast<half_t>(x.get() & 0x7fff); }

CK_TILE_HOST_DEVICE bool isnan(const half_t& x)
{
    uint16_t xx = x.get();
    return (xx & 0x7FFF) > 0x7C00;
}

CK_TILE_DEVICE half_t sqrt(half_t x) { return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };

CK_TILE_DEVICE half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };

CK_TILE_DEVICE half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };

CK_TILE_DEVICE half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif

} // namespace ck_tile
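A short host-side sketch of the half_t conversions and limits defined above; the sample values are illustrative assumptions:

// hypothetical usage of the half_t helpers above
ck_tile::half_t h   = ck_tile::float_to_fp16(3.14159f);
float f             = ck_tile::fp16_to_float(h); // ~3.1406, fp16 keeps 10 mantissa bits
ck_tile::half_t big = ck_tile::numeric<ck_tile::half_t>::max(); // raw 0x7BFF, i.e. 65504
ck_tile::half_t eps = ck_tile::numeric<ck_tile::half_t>::epsilon(); // raw 0x1800, i.e. 2^-10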
include/ck_tile/core/numeric/integer.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {

using index_t      = int32_t;
using long_index_t = int64_t;
using int8_t       = int8_t;

} // namespace ck_tile