Unverified commit db376dd8, authored by carlushuang, committed by GitHub

introducing ck_tile! (#1216)

* enable gfx940

* switch between intrinsic mfma routines on mi100/200 and mi300

* fix mfma_int8 on MI300

* disable 2 int8 examples on MI300

* Update cmake-ck-dev.sh

* restore gitignore file

* modify Jenkinsfile to the internal repo

* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx

Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core
  dependency-type: direct:production
  update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>

* initial enablement of gfx950

* fix clang format

* disable examples 31 and 41 int8 on gfx950

* add code

* fix build wip

* fix xx

* now can build

* naming

* minor fix

* wip fix

* fix macro for exp2; fix warpgemm a/b in transposedC

* unify as tuple_array

* Update the required Python version to 3.9

* Update executable name in test scripts

* re-structure tuple/array to avoid spill

* Merge function templates

* Fix format

* Add constraint to array<> ctor

* Re-use function

* Some minor changes

* remove wrong code in store_raw()

* fix compile issue in transpose

* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'

* convert more integral_constant -> constant, and fix formatting

* make sure thread_buffer can be tuple/array

* temp fix buffer_store spill

* do not use custom data type by default; now we can generate ISA-level identical code to opt_padding

* fix compile error, fp8 not ready now

* fix fp8 duplicated move/shift/and/or problem

* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode

* fix scratch in fp8 kernel

* update some readme

* fix merge from upstream

* sync with upstream

* sync upstream again

* sync 22

* remove unused

* fix clang-format

* update README of ck_tile example

* fix several issues

* set the minimum Python version to 3.8

* remove ck_tile example from default cmake target like all/install/check

* remove mistake

* 1) support recipe in generate.py 2) use simplified mask type 3) change left/right to be passed into karg

* fix some bugs in group-mode masking and codegen; update README

* F8 quantization for FMHA forward (#1224)

* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline

* Add element function to fmha api

* Adjust P elementwise function

* Fix bug in elementwise op; our elementwise op is not inout

* Add some elementwise ops in preparation for quantization

* Let generate.py generate different elementwise functions

* To prevent a compiler issue, remove the elementwise functions we have not used.

* Remove f8 pipeline; we should share the same pipeline even for f8

* Remove remove_cvref_t

* Avoid warning

* Fix wrong fp8 QK/KV block gemm setting

* Check fp8 rounding error in check_err()

* Set fp8 rounding error for check_err()

* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode

* 1. codegen the f8 api and kernel
2. f8 host code

* prevent warning in filter mode

* Remove not-in-use elementwise function kargs

* Remove more not-in-use elementwise function kargs

* Small refinements in C++ source files

* Use conditional_t<> to simplify code

* Support heterogeneous argument for binary function types

* Re-use already-existing scales<> functor template

* Fix wrong value produced by saturating

* Generalize the composes<> template

* Unify saturates<> implementation

* Fix type errors in composes<>

* Extend less_equal<>

* Reuse the existing template less_equal<> in check_err()

* Add equal<float> & equal<double>

* Rename check_err() parameter

* Rename check_err() parameter

* Add FIXME comment for adding new macro in future

* Remove unnecessary cast to void

* Eliminate duplicated code

* Avoid dividing api pool into more than 2 groups

* Use clearer variable names

* Use affirmative condition in if stmt

* Remove blank lines

* Do not use perfect forwarding in composes<>

* To fix compile error, revert generate.py back to 4439cc107dd90302d68a6494bdd33113318709f8

* Fix bug of p element function

* Add compute element op to host softmax

* Remove element function in api interface

* Extract user parameter

* Rename pscale and oscale variable

* rename f8 to fp8

* rename more f8 to fp8

* Add pipeline::operator() without element_functor

* 1. Remove deprecated pipeline enum
2. Refine host code parameter

* Use quantization range as input

* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s
3. Add init description

* Refine description

* prevent early return

* unify _squant kernel name in cpp, update README

* Adjust the default range.

* Refine error message and bias range

* Add fp8 benchmark and smoke test

* fix fp8 swizzle_factor=4 case

---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
parent dd34ab6e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>
struct space_filling_curve
{
static constexpr index_t TensorSize =
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
static_assert(0 < TensorSize,
"space_filling_curve should be used to access a non-empty tensor");
static constexpr index_t nDim = TensorLengths::size();
using Index = multi_index<nDim>;
static constexpr index_t ScalarPerVector =
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
static constexpr auto dim_access_order = DimAccessOrder{};
static constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(ordered_access_lengths)),
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
make_tuple(sequence<0>{}));
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
}
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
number<AccessIdx1dTail>)
{
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
"1D index out of range");
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
"1D index out of range");
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
return idx_tail - idx_head;
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d < get_num_of_access(), "1D index out of range");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
{
#if 0
/*
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
*/
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
constexpr auto access_strides =
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
constexpr auto idx_1d = number<AccessIdx1d>{};
// Given tensor strides \p access_strides and the 1D index of the space-filling curve, compute
// the idim-th element of the multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
auto res = idx_1d.value;
auto id = 0;
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
id = res / access_strides[kdim].value;
res -= id * access_strides[kdim].value;
});
return id;
};
constexpr auto id = compute_index_impl(idim);
return number<id>{};
};
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
#endif
constexpr auto forward_sweep = [&]() {
statically_indexed_array<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto idim) {
index_t tmp = ordered_access_idx[I0];
static_for<1, idim, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(idim) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate multi-dim tensor index
auto idx_md = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) =
!SnakeCurved || forward_sweep[idim]
? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
ScalarsPerAccess{};
}();
return idx_md;
}
// FIXME: rename this function
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
{
constexpr auto idx = get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
}
};
} // namespace ck_tile
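// Illustrative usage sketch (kept disabled; not part of this header, and the concrete values
// below are assumptions for demonstration only): a 4x8 tile accessed in dim order {0,1} with
// 1x4 scalars per access gives 8 vector accesses; with SnakeCurved = true every other row is
// traversed backwards, which keeps get_forward_step() between consecutive accesses small.
#if 0
namespace ck_tile {
using example_sfc = space_filling_curve<sequence<4, 8>,   // tensor lengths
                                        sequence<0, 1>,   // dim access order
                                        sequence<1, 4>>;  // scalars per access
static_assert(example_sfc::get_num_of_access() == 8);
// multi-dimensional index of the 3rd access, and the step from access 2 to access 3
constexpr auto example_idx  = example_sfc::get_index(number<3>{});
constexpr auto example_step = example_sfc::get_step_between(number<2>{}, number<3>{});
} // namespace ck_tile
#endif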
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// 128-bit buffer resource descriptor, held in SGPRs and supplied to buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
struct __attribute__((packed)) buffer_resource
{
const void* ptr;
uint32_t range;
uint32_t config;
};
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
return __builtin_bit_cast(int32x4_t, res);
}
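// Illustrative sketch (kept disabled; not part of this header): building the 128-bit resource
// for a wave. The pointer fills the first two dwords, 'range' is the buffer size in bytes used
// by the hardware out-of-bounds handling, and the last dword comes from
// CK_TILE_BUFFER_RESOURCE_3RD_DWORD (format/stride configuration for the target).
#if 0
CK_TILE_DEVICE int32x4_t example_make_resource(const float* p_global, index_t n_elements)
{
    // size is given in bytes; accesses past this range fall under the hardware
    // out-of-bounds behavior of buffer instructions
    return make_wave_buffer_resource(p_global, n_elements * sizeof(float));
}
#endif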
// TODO: glc/slc/...
template <index_t bytes>
struct buffer_load;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: the strict aliasing rule seems to fail when reinterpret_cast-ing between vector types
// (ext_vector_type(xxx))
template <>
struct buffer_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = float;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
struct buffer_load_if;
template <>
struct buffer_load_if<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x2_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
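// Illustrative sketch (kept disabled; not part of this header): the *_if variants predicate
// the memory instruction through the exec mask. v_cmpx_le_u32 sets each lane's exec bit to
// (1 <= flag), the buffer instruction then only runs on lanes whose flag is non-zero, and
// s_mov_b64 restores the previously saved exec mask afterwards.
#if 0
CK_TILE_DEVICE void example_predicated_load(float& dst, int32x4_t res, index_t v_offset, bool in_range)
{
    buffer_load_if<4>{}(dst, res, v_offset, 0 /*s_offset*/, 0 /*i_offset*/, in_range ? 1 : 0);
}
#endif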
template <index_t bytes>
struct buffer_store;
template <>
struct buffer_store<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile(
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile(
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_store<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
struct buffer_store_if;
template <>
struct buffer_store_if<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 8);
auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems to need bit_cast to the same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <>
struct buffer_store_if<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// clang-format off
namespace impl{
// can't use "+v" since there could be potential extra move(read/write)
// use "v" can help remove such duplicated moves
// besides, fake this as "memory" operation to force later valu after this fence
// TODO: may have scratch (because this is memory?)
// need to reduce extra move inside compiler
template<index_t N>
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
{
static_for<0, b.size(), 1>{}([&](auto i){
asm volatile(" " : : "v"(b.get(i)) : "memory");
});
}
#if 1
// the specializations below merge all size() dwords into a single asm statement
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array<float, 2>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array<float, 3>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array<float, 4>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array<float, 8>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array<float, 16>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
"v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
"v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory");
}
template<>
CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array<float, 32>& b)
{
asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
"v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
"v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
"v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})),
"v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})),
"v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})),
"v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})),
"v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory");
}
#endif
CK_TILE_DEVICE void insert_dummy_dep() {}
template<typename T>
CK_TILE_DEVICE void insert_dummy_dep(T & buffer)
{
// TODO: we expect sizeof(T) to be a multiple of a dword; sub-dword is always buggy
using da_type = array<float, (sizeof(T) + 3) / 4>;
auto & dummy = reinterpret_cast<da_type&>(buffer);
insert_dummy_dep_per_dword(dummy);
}
template<typename Tx, typename... Ty>
CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
{
insert_dummy_dep(bx);
insert_dummy_dep(by...);
}
}
// clang-format on
template <typename... T>
CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
impl::insert_dummy_dep(o...);
}
CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
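// Illustrative sketch (kept disabled; not part of this header): the variadic
// buffer_load_fence() above first waits on vmcnt and then threads the destination buffers
// through dummy "v" asm operands, so the compiler cannot hoist later VALU work above the wait.
#if 0
CK_TILE_DEVICE void example_load_then_fence(int32x4_t res, index_t v_offset)
{
    thread_buffer<float, 4> a, b;
    buffer_load<16>{}(a, res, v_offset, 0, 0);
    buffer_load<16>{}(b, res, v_offset, 0, 16);
    buffer_load_fence(0, a, b); // wait for both loads and keep a/b live as asm inputs
    // ... consume a and b ...
}
#endif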
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
CK_TILE_DEVICE_EXTERN int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
CK_TILE_DEVICE_EXTERN int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
// buffer load i16
CK_TILE_DEVICE_EXTERN int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
CK_TILE_DEVICE_EXTERN int16x2_t
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
CK_TILE_DEVICE_EXTERN int16x4_t
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
// buffer load i32
CK_TILE_DEVICE_EXTERN int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
CK_TILE_DEVICE_EXTERN int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
CK_TILE_DEVICE_EXTERN int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// buffer load fp16
CK_TILE_DEVICE_EXTERN _Float16
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
CK_TILE_DEVICE_EXTERN fp16x2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
CK_TILE_DEVICE_EXTERN fp16x4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// buffer load fp32
CK_TILE_DEVICE_EXTERN float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
CK_TILE_DEVICE_EXTERN fp32x2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
CK_TILE_DEVICE_EXTERN fp32x4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// buffer store i8
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
// buffer store i16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
// buffer store i32
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// buffer store fp16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// buffer store fp32
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// buffer atomic-add fp16
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
fp16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// buffer atomic-add fp32
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-max fp64
CK_TILE_DEVICE_EXTERN double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource
int voffset, // dst_thread_addr_offset
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0)
{
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
: "memory");
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
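// Illustrative sketch (kept disabled; not part of this header): issue an asynchronous dword
// load straight into LDS via the "lds" buffer-load form, then wait before the data is read.
// The smem pointer is only a dummy dependency for the compiler; the actual LDS destination
// is configured by the caller (typically through the M0 register on gfx9-class hardware).
#if 0
CK_TILE_DEVICE void example_async_copy_to_lds(void* smem, int32x4_t res, index_t v_offset)
{
    async_buffer_load_dword(smem, res, v_offset, 0, 0);
    async_buffer_load_fence(0); // s_waitcnt vmcnt(0): the LDS write has landed
    // ... LDS barrier, then consume the data from shared memory ...
}
#endif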
// memory coherency bits for buffer store/load instructions
// check the ISA manual for each GFX target
// e.g. for
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
// page 67~68
enum struct amd_buffer_coherence_enum
{
coherence_default = 0, // default value
glc = 1,
slc = 2,
glc_slc = 3,
};
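// Illustrative sketch (kept disabled; not part of this header): the enum value is passed
// verbatim as the final "glc_slc" argument of the llvm.amdgcn.raw.buffer.* intrinsics, so
// bit 0 requests GLC and bit 1 requests SLC on the emitted buffer instruction (see the ISA
// manuals referenced above for the exact meaning on each GFX target).
#if 0
CK_TILE_DEVICE float example_load_glc(int32x4_t res, index_t v_offset)
{
    return llvm_amdgcn_raw_buffer_load_fp32(
        res, v_offset, 0, static_cast<index_t>(amd_buffer_coherence_enum::glc));
}
#endif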
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<int8_t, N>
amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
using rtn_type = thread_buffer<int8_t, N>;
if constexpr(N == 1)
{
return bit_cast<rtn_type>(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 4)
{
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 8)
{
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 32)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
thread_buffer<int32_t, 8> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 64)
{
int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp2 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp3 =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
thread_buffer<int32_t, 16> tmp;
tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
tmp.template get_as<int32x4_t>()(number<2>{}) = tmp2;
tmp.template get_as<int32x4_t>()(number<3>{}) = tmp3;
return bit_cast<rtn_type>(tmp);
}
}
#ifndef BUFFER_LOAD_USE_INLINEASM
#define BUFFER_LOAD_USE_INLINEASM 0
#endif
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
if constexpr(std::is_same<T, float>::value) // fp32
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return tmp;
}
else if constexpr(N == 16)
{
thread_buffer<float, 16> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<2>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(float),
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<3>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(float),
static_cast<index_t>(coherence));
return tmp;
}
}
else if constexpr(std::is_same<T, fp16_t>::value) // fp16
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
// use fp32 load to mimic fp16 load
fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
if constexpr(N == 1)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 2)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 4)
{
return bit_cast<rtn_type>(
llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence)));
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else // other datatype
{
auto raw_data = amd_buffer_load_impl_with_bytes<sizeof(T) * N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
return bit_cast<rtn_type>(raw_data);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t flag = 0)
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_load instruction");
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_load_if<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
}
else
{
buffer_load<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0)
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset);
}
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i8(bit_cast<int8_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 32)
{
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
}
else if constexpr(N == 64)
{
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 8,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i32x4(
src_thread_data.template get_as<int32x4_t>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value) // fp32
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32(bit_cast<float>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast<fp32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data.template get_as<fp32x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(
src_thread_data.template get_as<fp32x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
}
else if constexpr(std::is_same<T, fp16_t>::value) // fp16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast<fp16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
#if 0
thread_buffer<fp16_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(fp16_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<fp32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(bit_cast<int16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(bit_cast<int16x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<int16x4_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(
src_thread_data.template get_as<int16x4_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bf16_t),
static_cast<index_t>(coherence));
}
}
else
{
using r_t = thread_buffer<int8_t, sizeof(T) * N>;
amd_buffer_store_impl_with_bytes<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset);
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
index_t is_valid_element = 1)
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
"wrong! not supported by buffer_store instruction");
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_store_if<sizeof(type)>{}(dst_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0,
is_valid_element);
}
else
{
buffer_store<sizeof(type)>{}(dst_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(std::is_same<T, float>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast<float>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(
src_thread_data.template get_as<float>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(float),
0);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
src_thread_data.template get_as<fp16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(fp16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
src_thread_data.template get_as<fp16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(fp16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(
src_thread_data.template get_as<int32_t>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(int32_t),
0);
}
}
}
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(std::is_same<T, double>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast<double>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(double),
0);
llvm_amdgcn_raw_buffer_atomic_max_fp64(
src_thread_data.template get_as<double>()[number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(double),
0);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
// oob_conditional_check: dynamically check whether the access is out-of-bounds
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = [&]() {
if constexpr(oob_conditional_check)
return src_thread_element_valid ? 0 : 0x80000000;
else
return 0;
}();
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
else
return tmp;
#endif
}
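// A minimal usage sketch (illustrative only, not part of this header): each lane loads N
// contiguous elements from a wavewise global pointer and receives zeros when its offset is
// flagged invalid. The names p_src, offset, and n below are placeholders.
//   const bool valid = (offset + 4) <= n;
//   auto v = amd_buffer_load_invalid_element_return_zero<float, 4>(p_src, offset, valid, n);
//   // v is a thread_buffer<float, 4>, zero-filled when valid == false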
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE thread_buffer<T, N>
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
else
return tmp;
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
}
// Unfortunately, an async copy cannot guarantee that invalid data ends up as zero inside LDS
// unless zero is manually written to the proper LDS address beforehand,
// so the invalid_element check is not supported for now.
// The buffer_load out-of-bounds behavior still applies.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
}
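// A minimal usage sketch (illustrative only): issue an async copy of one element per lane from
// a wavewise global pointer straight into LDS, assuming a single-element-per-lane load is a
// supported vector size for the target. The names smem, p_src, offset, and n are placeholders;
// out-of-bound lanes rely on the buffer OOB behavior described above, not on zero-filling.
//   amd_async_buffer_load_with_oob<float, 1>(smem, p_src, offset, n);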
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = [&]() {
if constexpr(oob_conditional_check)
return dst_thread_element_valid ? 0 : 0x80000000;
else
return 0;
}();
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if constexpr(oob_conditional_check)
{
if(dst_thread_element_valid)
{
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
}
else
{
amd_buffer_store_impl<T, N, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
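// A minimal usage sketch (illustrative only): store a per-thread accumulator back to a wavewise
// global pointer, skipping invalid lanes (or pushing them out of range via the offset trick).
// The names acc, p_dst, offset, valid, and n are placeholders.
//   thread_buffer<float, 4> acc{numeric<float>::zero()};
//   amd_buffer_store<float, 4>(acc, p_dst, offset, valid, n);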
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_thread_element_valid);
}
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
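// A minimal usage sketch (illustrative only): atomically accumulate a per-thread partial result
// into global memory; invalid lanes are either skipped or redirected out of range by the offset
// trick. The names partial, p_dst, offset, valid, and n are placeholders.
//   amd_buffer_atomic_add<float, 4>(partial, p_dst, offset, valid, n);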
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_max_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_max_impl<T, N>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
T* lds_base_ptr,
const index_t lds_offset,
const bool is_valid,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
const int32x4_t src_resource =
make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T));
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource));
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
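// A minimal usage sketch (illustrative only): since each thread must move exactly one DWORD, a
// direct global->LDS load of float data uses NumElemsPerThread == 1 (fp16_t would use 2). The
// names p_global, g_off, p_lds, lds_off, valid, and n are placeholders.
//   amd_direct_load_global_to_lds<float, 1>(p_global, g_off, p_lds, lds_off, valid, n);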
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
namespace ck_tile {
enum struct address_space_enum
{
generic,
global,
lds,
sgpr,
vgpr,
};
enum struct memory_operation_enum
{
set,
atomic_add,
atomic_max,
add
};
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
// warpSize is defined by HIP
return warpSize;
}
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
CK_TILE_DEVICE index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
CK_TILE_DEVICE void block_sync_lds()
{
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\
s_waitcnt vmcnt(0) \n \
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
}
CK_TILE_DEVICE void s_nop()
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
#else
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <stdint.h>
namespace ck_tile {
// TODO: we have "memory" clobber here because this inline asm is used for async copy
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
{
asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory");
}
// NOTE: this is an immediate value
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
{
asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory");
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_up(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
{
#if 0
return __shfl_down(v_local, lane_delta);
#elif 1
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
(__lane_id() << 2) + (lane_delta << 2), bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
#endif
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif
#ifdef __HIPCC__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#endif
#define CK_TILE_FLOAT_TO_FP8_STANDARD 0
#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1
#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT
#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD
#endif
// On older ROCm releases we had to fall back to the tuple-based implementation here,
// so turn on _USE_TUPLE if you hit a compiler error; otherwise use _USE_ARRAY.
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0
#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1
#ifndef CK_TILE_THREAD_BUFFER_DEFAULT
#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY
#endif
#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// if using the tuple-based thread_buffer implementation, {} brace init must be supported
// ... with behavior similar to array
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1
#else
#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0
#endif
#endif
#ifndef CK_TILE_USE_LAUNCH_BOUNDS
#define CK_TILE_USE_LAUNCH_BOUNDS 1
#endif
#ifndef CK_TILE_TIME_KERNEL
#define CK_TILE_TIME_KERNEL 1
#endif
#define CK_TILE_MAX_THREAD_PER_BLOCK 256
#define CK_TILE_MIN_BLOCK_PER_CU 2
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_LOAD
#define CK_TILE_USE_AMD_BUFFER_LOAD 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_STORE
#define CK_TILE_USE_AMD_BUFFER_STORE 1
#endif
#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)) // for GPU code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
#endif
#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0
#endif
#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// use aggregate initialization for this type
// e.g. array<index_t, 4> buf {0};    => {0, 0, 0, 0} (all elements zeroed)
//      array<index_t, 4> buf {3, 2}; => {3, 2, 2, 2} (not {3, 2, 0, 0})
// use make_array_with({...}) to construct an array with behavior compatible with old CK
// TODO: manually added constructor, same as old CK
template <typename T_, index_t N_>
struct array
{
using value_type = T_;
static constexpr index_t N = N_;
// TODO: do we need this?
// using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type))));
// union {
value_type data[N];
// bulk_type __content;
//};
CK_TILE_HOST_DEVICE constexpr array() : data{} {}
// NOTE: fills the tail of data[] by repeating the last value of the initializer list;
// this behavior differs from std containers
CK_TILE_HOST_DEVICE constexpr array(std::initializer_list<value_type> ilist)
{
// ilist.size() must not exceed N; the size of an initializer_list argument is not a
// constant expression, so this cannot be enforced with a static_assert here.
index_t i = 0;
value_type vlast = value_type{};
for(const value_type& val : ilist)
{
data[i] = val;
vlast = val;
++i;
}
for(; i < N; ++i)
{
data[i] = vlast;
}
}
template <typename Y,
typename = std::enable_if_t<std::is_convertible_v<Y, value_type> ||
std::is_constructible_v<Y, value_type>>>
CK_TILE_HOST_DEVICE explicit constexpr array(Y c)
{
for(auto i = 0; i < size(); i++)
data[i] = static_cast<value_type>(c);
}
// template <typename Y>
// CK_TILE_HOST_DEVICE constexpr array(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// }
// CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o)
// {
// // static_assert(ArrayType::size() == size(), "wrong! size not the same");
// __content = o.__content;
// return *this;
// }
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<value_type>; }
// clang-format off
CK_TILE_HOST_DEVICE constexpr auto& get() { return data; }
CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; }
CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& get(number<I>) { return data[I]; }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get(number<I>) const { return data[I]; }
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible
#if 0
template <typename ArrayLike>
CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr)
{
static_assert(ArrayLike::size() == size(), "wrong! size not the same");
for(index_t i = 0; i < size(); ++i)
{
data[i] = arr[i];
}
return *this;
}
#endif
// type punning (strict aliasing) member functions for read/write
// aliasing this array of type "T", "N" elements
// as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements
#define AR_AS_COM_() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as()
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as() const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data); }
// the index below refers to the element position *AFTER* the type conversion, not before
template <typename Tx> CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(i); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr auto& get_as(number<I>)
{ AR_AS_COM_(); return reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr const auto& get_as(number<I>) const
{ AR_AS_COM_(); return reinterpret_cast<const array<Tx, vx>&>(data).at(number<I>{}); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ AR_AS_COM_(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef AR_AS_COM_
// clang-format on
};
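// A minimal usage sketch of the get_as<>()/set_as<>() type-punning accessors (illustrative
// only): an array<fp16_t, 4> occupies 8 bytes and can be viewed as array<uint32_t, 2>; note
// that the index passed to get_as()/set_as() refers to the position after the conversion.
//   array<fp16_t, 4> a{};
//   uint32_t packed = a.get_as<uint32_t>(1); // second 32-bit chunk (elements 2 and 3)
//   a.set_as<uint32_t>(0, packed);           // overwrite the first 32-bit chunk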
// empty Array
template <typename T>
struct array<T, 0>
{
using value_type = T;
CK_TILE_HOST_DEVICE constexpr array() {}
CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v<T>; };
CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); }
};
template <typename>
struct vector_traits;
// specialization for array
template <typename T, index_t N>
struct vector_traits<array<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
namespace details {
template <class>
struct is_ref_wrapper : std::false_type
{
};
template <class T>
struct is_ref_wrapper<std::reference_wrapper<T>> : std::true_type
{
};
template <class T>
using not_ref_wrapper = std::negation<is_ref_wrapper<std::decay_t<T>>>;
template <class D, class...>
struct return_type_helper
{
using type = D;
};
template <class... Ts>
struct return_type_helper<void, Ts...> : std::common_type<Ts...>
{
static_assert(std::conjunction_v<not_ref_wrapper<Ts>...>,
"Ts cannot contain reference_wrappers when D is void");
};
template <class D, class... Ts>
using return_type = array<typename return_type_helper<D, Ts...>::type, sizeof...(Ts)>;
} // namespace details
template <typename D = void, typename... Ts>
CK_TILE_HOST_DEVICE constexpr details::return_type<D, Ts...> make_array(Ts&&... ts)
{
return {std::forward<Ts>(ts)...};
}
// // make empty array
// template <typename T>
// CK_TILE_HOST_DEVICE constexpr auto make_array()
// {
// return array<T, 0>{};
// }
// compatible with old CK's initializer: make an array and fill the tail with the last element
// of the initializer_list
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list<T> ilist)
{
return array<T, Size>(ilist);
}
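// A minimal usage sketch (illustrative only): make_array_with repeats the last initializer to
// fill the remaining slots, matching old CK behavior, whereas make_array deduces the size.
//   auto a = make_array_with<index_t, 4>({3, 2}); // -> {3, 2, 2, 2}
//   auto b = make_array<index_t>(3, 2);           // -> array<index_t, 2>{3, 2}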
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator==(const array<T, Size>& a, const array<T, Size>& b)
{
bool same = true;
for(index_t i = 0; i < Size; ++i)
{
if(a[i] != b[i])
{
same = false;
break;
}
}
return same;
}
template <typename T, index_t Size>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const array<T, Size>& b)
{
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
static_assert(N <= X::size(), "");
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array<TData, NSize>& a, const TData& x)
{
array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r[number<NSize>{}] = x;
return r;
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple<Ts...>& a, const T& x)
{
return container_concat(make_tuple(x), a);
}
template <typename... Ts, typename T>
CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple<Ts...>& a, const T& x)
{
return container_concat(a, make_tuple(x));
}
// reorder array
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
}
template <typename TData, index_t NSize, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array, sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder array
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_new2old(const array<TData, NSize>& old_array,
const map<index_t, index_t>& new2old)
{
array<TData, NSize> new_array;
for(const auto& [new_pos, old_pos] : new2old)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
template <typename TData, index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto
container_reorder_given_old2new(const array<TData, NSize>& old_array,
const map<index_t, index_t>& old2new)
{
array<TData, NSize> new_array;
for(const auto& [old_pos, new_pos] : old2new)
{
new_array(new_pos) = old_array[old_pos];
}
return new_array;
}
// reorder tuple
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<Ts...>& old_tuple,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_tuple(old_tuple[number<IRs>{}]...);
}
template <typename... Ts, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple<Ts...>& old_tuple,
sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
// reorder sequence
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is...> /* old_seq */,
sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
}
template <index_t... Is, index_t... IRs>
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is...> old_seq,
sequence<IRs...> /* old2new */)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
return container_reorder_given_new2old(old_seq, new2old);
}
#if 0
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, number<IBegin>{}, init);
}
#else
// i is the index, r_old is the current reduction
template <typename Container,
typename Reduce,
typename ROld,
index_t I,
index_t IEnd,
index_t IStep>
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(
const Container& x, Reduce reduce, ROld r_old, number<I> i, number<IEnd>, number<IStep>)
{
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
return container_reduce_impl(
x, reduce, r_new, i + number<IStep>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return r_new;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::size(),
index_t IStep = 1>
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
number<IBegin> = number<0>{},
number<IEnd> = number<Container::size()>{},
number<IStep> = number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
if constexpr(IEnd > IBegin)
{
return container_reduce_impl(
x, reduce, init, number<IBegin>{}, number<IEnd>{}, number<IStep>{});
}
else
{
return init;
}
}
#endif
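// A minimal usage sketch (illustrative only): reduce an array of lengths into their product.
// The reduce callable receives (element, running_value); the names lens and prod are
// placeholders.
//   constexpr auto lens = make_array<index_t>(2, 3, 4);
//   constexpr auto prod =
//       container_reduce(lens, [](auto x, auto r) { return x * r; }, index_t{1}); // -> 24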
template <typename TData, index_t NSize, typename Reduce>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const array<TData, NSize>& x, Reduce f, TData init)
{
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename TData, index_t NSize, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const array<TData, NSize>& x, Reduce f, Init init)
{
#if 0
array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(number<0>{}) = r;
return y;
#else
array<TData, NSize> y;
TData r = init;
for(index_t i = NSize - 1; i > 0; --i)
{
y(i) = r;
r = f(r, x[i]);
}
y(0) = r;
return y;
#endif
}
template <index_t... Is, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const sequence<Is...>& seq, Reduce f, number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, number<Init>{});
}
#if 0
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(
const tuple<Xs...>& x, Reduce reduce, number<I> i, YOld y_old, ROld r_old)
{
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
}
template <typename... Xs, typename Reduce, typename Init>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_exclusive_scan(const tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
return container_reverse_exclusive_scan_impl(
x, reduce, number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update this, like container_reverse_exclusive_scan, to deal with tuples of number<>
template <typename... Xs, typename Reduce, typename TData>
CK_TILE_HOST_DEVICE constexpr auto
container_reverse_inclusive_scan(const tuple<Xs...>& x, Reduce f, TData init)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[number<0>{}]);
y(number<0>{}) = r;
return y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys)
{
return container_concat(x, container_concat(ys...));
}
template <typename T, index_t NX, index_t NY>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const array<T, NX>& ax, const array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array<T>(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x)
{
return x;
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array<T, N>& arr, sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_array<T>(arr[Is]...);
}
else
{
return array<T, 0>{};
}
}
template <typename... Ts, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple<Ts...>& tup, sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
return make_tuple(tup[number<Is>{}]...);
}
else
{
return tuple<>{};
}
}
template <typename T, index_t N, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void
set_container_subset(array<T, N>& y, sequence<Is...> picks, const array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
for(index_t i = 0; i < picks.size(); ++i)
{
y(picks[i]) = x[i];
}
}
}
template <typename Y, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence<Is...> picks, const X& x)
{
static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size");
if constexpr(sizeof...(Is) > 0)
{
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
}
// return the index of the first occurrence of the value in the sequence,
// or seq.size() if it is not found
template <index_t... Is>
constexpr index_t container_find(sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.size();
}
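// A minimal usage sketch (illustrative only):
//   constexpr auto s = sequence<4, 1, 7>{};
//   container_find(s, 7); // -> 2
//   container_find(s, 9); // -> 3 (== s.size(), i.e. not found)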
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence<Is...>)
{
using Seq = sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::at(i);
return number<tmp>{};
},
number<Seq::size()>{});
}
#if 0
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, a_size, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#else
// capturing a constexpr index_t triggers "-Wunused-lambda-capture", so a_size is left out of
// the capture list
// TODO: this is ugly
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \
[a_of_b_impl, bs_sizes] { \
return ck_tile::generate_tuple( \
[=](auto i) { \
constexpr auto b_impl = a_of_b_impl[i]; \
constexpr index_t b_size = bs_sizes[i]; \
constexpr auto b = TO_SEQUENCE(b_impl, b_size); \
return b; \
}, \
ck_tile::number<a_size>{}); \
}()
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
// naive map
template <typename key, typename data, index_t max_size = 128>
struct map
{
using pair_type = tuple<key, data>;
using impl_type = array<pair_type, max_size>;
impl_type impl_;
index_t size_;
struct iterator
{
impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); }
};
struct const_iterator
{
const impl_type& impl_;
index_t pos_;
CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos)
: impl_{impl}, pos_{pos}
{
}
CK_TILE_HOST_DEVICE constexpr const_iterator& operator++()
{
pos_++;
return *this;
}
CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const
{
return other.pos_ != pos_;
}
CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); }
};
CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {}
CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; }
CK_TILE_HOST_DEVICE void clear() { size_ = 0; }
CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const
{
for(index_t i = 0; i < size(); i++)
{
if(impl_[i].template at<0>() == k)
{
return i;
}
}
return size_;
}
CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const
{
return const_iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr iterator find(const key& k)
{
return iterator{impl_, find_position(k)};
}
CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const
{
const auto it = find(k);
// FIXME
// assert(it.pos_ < size());
return impl_[it.pos_].template at<1>();
}
CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k)
{
auto it = find(k);
// if entry not found
if(it.pos_ == size())
{
impl_(it.pos_).template at<0>() = k;
size_++;
}
// FIXME
// assert(size_ <= max_size);
return impl_(it.pos_).template at<1>();
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr const_iterator end() const
{
return const_iterator{impl_, size_};
}
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; }
// WARNING: needed by compiler for C++ range-based for loop only, don't use this function!
CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; }
CK_TILE_HOST_DEVICE void print() const
{
printf("map{size_: %d, ", size_);
//
printf("impl_: [");
//
for(const auto& [k, d] : *this)
{
printf("{key: ");
print(k);
printf(", data: ");
print(d);
printf("}, ");
}
//
printf("]");
//
printf("}");
}
};
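// A minimal usage sketch of the naive map (illustrative only): operator() inserts the key if it
// is missing, operator[] is lookup-only, and iteration yields tuple<key, data> pairs.
//   map<index_t, index_t> m;
//   m(3) = 30;          // insert {3, 30}
//   index_t v = m[3];   // lookup -> 30
//   for(const auto& [k, d] : m) { /* visits {3, 30} */ }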
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <cstddef>
namespace ck_tile {
// TODO: this structure is not intended to be used directly by users
template <index_t MaxSize>
struct meta_data_buffer
{
CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs)
: buffer_{}, size_{0}
{
push(x, xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr void push(const T& data)
{
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
auto tmp = bit_cast<array<std::byte, size>>(data);
for(int i = 0; i < size; i++)
{
buffer_(size_) = tmp[i];
size_++;
}
}
}
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... xs)
{
push(x);
push(xs...);
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const
{
T data;
if constexpr(!std::is_empty_v<T>)
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
data = bit_cast<T>(tmp);
}
return data;
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const
{
constexpr index_t size = sizeof(T);
array<std::byte, size> tmp;
for(int i = 0; i < size; i++)
{
tmp(i) = buffer_[pos];
pos++;
}
auto data = bit_cast<T>(tmp);
return data;
}
//
array<std::byte, MaxSize> buffer_;
index_t size_ = 0;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
// Don't use this directly. It exists for old CK's internal usage;
// in the future always use array instead
template <index_t N>
using multi_index = array<index_t, N>;
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
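// A minimal usage sketch (illustrative only): multi_index is just array<index_t, N>, so the
// element-wise operators defined later in this header apply directly.
//   auto idx    = make_multi_index(1, 2, 3);       // multi_index<3>
//   auto sum    = idx + make_multi_index(4, 5, 6); // {5, 7, 9}
//   auto scaled = 2 * idx;                         // {2, 4, 6}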
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <index_t NSize, typename X>
CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index<NSize>& y, const X& x)
{
static_assert(X::size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& a, const T& b)
{
using type = multi_index<NSize>;
static_assert(T::size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; });
return r;
}
// multi_index = index_t * multi_index
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index<NSize>& x)
{
multi_index<NSize> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// multi_index = multi_index * index_t
template <index_t NSize>
CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index<NSize>& x, index_t a)
{
return a * x;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
template <typename Seq, index_t I>
struct sequence_split;
template <typename>
struct sequence_reverse;
template <typename>
struct sequence_map_inverse;
template <typename>
struct is_valid_sequence_map;
template <index_t I, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>);
template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq);
namespace impl {
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
template <index_t I, typename... Ts>
using at_index_t = __type_pack_element<I, Ts...>;
} // namespace impl
// we could implement this as below, similar to std, but let's keep the symbol names short...
// template< class T, T... Ints >
// class integer_sequence;
template <index_t... Is>
struct sequence
{
using type = sequence;
using value_type = index_t;
CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; };
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto get()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto get(number<I>)
{
static_assert(I < size(), "wrong! I too large");
return number<get<I>()>{};
}
CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I)
{
// the trailing dummy element prevents the compiler from complaining about an empty array when size() == 0
const index_t mData[size() + 1] = {Is..., 0};
return mData[I];
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto at()
{
static_assert(I < size(), "wrong! I too large");
return number<impl::at_index_t<I, constant<Is>...>{}>{};
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto at(number<I>)
{
static_assert(I < size(), "wrong! I too large");
return number<get<I>()>{};
}
template <typename I>
CK_TILE_HOST_DEVICE constexpr auto operator[](I i) const
{
return at(i);
}
template <index_t... IRs>
CK_TILE_HOST_DEVICE static constexpr auto reorder_new_to_old(sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs),
"wrong! reorder map should have the same size as the sequence to be reordered");
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
return sequence<type::get(number<IRs>{})...>{};
}
// MapOld2New is sequence<...>
template <typename MapOld2New>
CK_TILE_HOST_DEVICE static constexpr auto reorder_old_to_new(MapOld2New)
{
static_assert(MapOld2New::size() == size(),
"wrong! reorder map should have the same size as the sequence to be reordered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
}
CK_TILE_HOST_DEVICE static constexpr auto reverse()
{
return typename sequence_reverse<type>::type{};
}
CK_TILE_HOST_DEVICE static constexpr auto front()
{
static_assert(size() > 0, "wrong!");
return get(number<0>{});
}
CK_TILE_HOST_DEVICE static constexpr auto back()
{
static_assert(size() > 0, "wrong!");
return get(number<size() - 1>{});
}
CK_TILE_HOST_DEVICE static constexpr auto pop_front() { return sequence_pop_front(type{}); }
CK_TILE_HOST_DEVICE static constexpr auto pop_back() { return sequence_pop_back(type{}); }
template <index_t... Xs>
CK_TILE_HOST_DEVICE static constexpr auto push_front(sequence<Xs...>)
{
return sequence<Xs..., Is...>{};
}
template <index_t... Xs>
CK_TILE_HOST_DEVICE static constexpr auto push_front(number<Xs>...)
{
return sequence<Xs..., Is...>{};
}
template <index_t... Xs>
CK_TILE_HOST_DEVICE static constexpr auto push_back(sequence<Xs...>)
{
return sequence<Is..., Xs...>{};
}
template <index_t... Xs>
CK_TILE_HOST_DEVICE static constexpr auto push_back(number<Xs>...)
{
return sequence<Is..., Xs...>{};
}
// pick up the elements at indices <Ids...>
template <index_t... Ids>
CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
{
return sequence<type::get(number<Ids>{})...>{};
}
template <index_t... Ids>
CK_TILE_HOST_DEVICE static constexpr auto extract(sequence<Ids...>)
{
return sequence<type::get(number<Ids>{})...>{};
}
// modify element at index "I" with value "X"
template <index_t I, index_t X>
CK_TILE_HOST_DEVICE static constexpr auto modify(number<I>, number<X>)
{
static_assert(I < size(), "wrong!");
using seq_split = sequence_split<type, I>;
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.pop_front();
return seq_left.push_back(number<X>{}).push_back(seq_right);
}
template <typename F>
CK_TILE_HOST_DEVICE static constexpr auto transform(F f)
{
return sequence<f(Is)...>{};
}
CK_TILE_HOST_DEVICE static void print()
{
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
};
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;
template <index_t... Ints>
struct __integer_sequence<index_t, Ints...>
{
using seq_type = sequence<Ints...>;
};
} // namespace impl
// similar to std::make_index_sequence
template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
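// A minimal usage sketch (illustrative only), analogous to std::make_index_sequence:
//   using s = make_index_sequence<4>; // sequence<0, 1, 2, 3>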
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
template <index_t... Xs, index_t... Ys>
struct sequence_merge<sequence<Xs...>, sequence<Ys...>>
{
using type = sequence<Xs..., Ys...>;
};
template <typename Seq>
struct sequence_merge<Seq>
{
using type = Seq;
};
// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(number<I>{});
using type = sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = sequence<>;
};
using type = typename sequence_gen_impl<0, NSize, F>::type;
};
// arithmetic sequence
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
struct F
{
CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t i) const
{
return i * Increment + IBegin;
}
};
using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
using type1 = sequence<>;
static constexpr bool kHasContent =
(Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd);
using type = typename std::conditional<kHasContent, type0, type1>::type;
};
template <index_t IEnd>
struct arithmetic_sequence_gen<0, IEnd, 1>
{
using type = make_index_sequence<IEnd>;
};
// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t) const { return I; }
};
using type = typename sequence_gen<NSize, F>::type;
};
// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I, Is...>, Reduce, Init>
{
using old_scan = typename sequence_reverse_inclusive_scan<sequence<Is...>, Reduce, Init>::type;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front());
using type = typename sequence_merge<sequence<new_reduce>, old_scan>::type;
};
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I>, Reduce, Init>
{
using type = sequence<Reduce{}(I, Init)>;
};
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<>, Reduce, Init>
{
using type = sequence<>;
};
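// illustrative example (not part of the original header), scanning from the right:
//   sequence_reverse_inclusive_scan<sequence<2, 3, 4>, plus<index_t>, 1>::type
//   // sequence<10, 8, 5>  (4+1=5, then 3+5=8, then 2+8=10)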
// split sequence
template <typename Seq, index_t I>
struct sequence_split
{
static constexpr index_t NSize = Seq{}.size();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using left_type = decltype(Seq::extract(range0{}));
using right_type = decltype(Seq::extract(range1{}));
};
#if 0
// reverse sequence
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
struct sequence_reverse<sequence<I>>
{
using type = sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<sequence<I0, I1>>
{
using type = sequence<I1, I0>;
};
#endif
namespace impl {
template <typename Id, index_t... Ns>
struct seq_reverse;
template <index_t... Ids, index_t... Ns>
struct seq_reverse<sequence<Ids...>, Ns...>
{
template <index_t I>
using element = impl::at_index_t<I, constant<Ns>...>;
using type = sequence<element<(sizeof...(Ns) - 1 - Ids)>::value...>;
};
} // namespace impl
template <index_t... Ns>
struct sequence_reverse<sequence<Ns...>>
: impl::seq_reverse<make_index_sequence<sizeof...(Ns)>, Ns...>
{
};
// template <index_t... Ns>
// using sequence_reverse_t = typename sequence_reverse<Ns...>::type;
#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, sequence<Xs...>, sequence<Ys...>>
{
using type = sequence<Reduce{}(Xs, Ys)...>;
};
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool choose_left = LeftValues::front() < RightValues::front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::front() : RightValues::front();
static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
using new_left_values = typename std::
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
using new_left_ids =
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
using new_right_values = typename std::
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
using new_right_ids =
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
sequence<>,
sequence<>,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<sequence<>,
sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
sequence<>,
sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values = typename std::
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
using sorted_ids =
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
{
using sorted_values = sequence<Value>;
using sorted_ids = sequence<Id>;
};
template <typename Compare>
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
{
using sorted_values = sequence<>;
using sorted_ids = sequence<>;
};
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
};
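// illustrative example (not part of the original header):
//   using sort = sequence_sort<sequence<3, 1, 2>, less<index_t>>;
//   sort::type                  // sequence<1, 2, 3>
//   sort::sorted2unsorted_map   // sequence<1, 2, 0>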
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t current_value = RemainValues::front();
static constexpr index_t current_id = RemainIds::front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::back());
using new_remain_values = decltype(RemainValues::pop_front());
using new_remain_ids = decltype(RemainIds::pop_front());
using new_uniquified_values =
typename std::conditional<is_unique_value,
decltype(UniquifiedValues::push_back(
number<current_value>{})),
UniquifiedValues>::type;
using new_uniquified_ids =
typename std::conditional<is_unique_value,
decltype(UniquifiedIds::push_back(number<current_id>{})),
UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<sequence<>,
sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{
using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
};
template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify
{
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::pop_front()),
decltype(SortedIds::pop_front()),
sequence<SortedValues::front()>,
sequence<SortedIds::front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
template <typename SeqMap>
struct is_valid_sequence_map
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
typename sequence_sort<SeqMap, less<index_t>>::type>
{
};
template <typename SeqMap>
struct sequence_map_inverse
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
static constexpr auto new_y2x =
WorkingY2X::modify(X2Y::get(number<XBegin>{}), number<XBegin>{});
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::size(), 0>::type,
0,
SeqMap::size()>::type;
};
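// illustrative example (not part of the original header):
//   sequence_map_inverse<sequence<1, 2, 0>>::type   // sequence<2, 0, 1>
//   (element y of the result tells which position x maps to y in the input)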
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator==(sequence<Xs...>, sequence<Ys...>)
{
return ((Xs == Ys) && ...);
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr bool operator!=(sequence<Xs...> x, sequence<Ys...> y)
{
return !(x == y);
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return sequence<(Xs + Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return sequence<(Xs - Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return sequence<(Xs * Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return sequence<(Xs / Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return sequence<(Xs % Ys)...>{};
}
template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator+(sequence<Xs...>, number<Y>)
{
return sequence<(Xs + Y)...>{};
}
template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator-(sequence<Xs...>, number<Y>)
{
return sequence<(Xs - Y)...>{};
}
template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator*(sequence<Xs...>, number<Y>)
{
return sequence<(Xs * Y)...>{};
}
template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator/(sequence<Xs...>, number<Y>)
{
return sequence<(Xs / Y)...>{};
}
template <index_t... Xs, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto operator%(sequence<Xs...>, number<Y>)
{
return sequence<(Xs % Y)...>{};
}
template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator+(number<Y>, sequence<Xs...>)
{
return sequence<(Y + Xs)...>{};
}
template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator-(number<Y>, sequence<Xs...>)
{
return sequence<(Y - Xs)...>{};
}
template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator*(number<Y>, sequence<Xs...>)
{
return sequence<(Y * Xs)...>{};
}
template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator/(number<Y>, sequence<Xs...>)
{
return sequence<(Y / Xs)...>{};
}
template <index_t Y, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto operator%(number<Y>, sequence<Xs...>)
{
return sequence<(Y % Xs)...>{};
}
template <index_t I, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence<I, Is...>)
{
return sequence<Is...>{};
}
template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq)
{
static_assert(Seq::size() > 0, "wrong! cannot pop an empty sequence!");
return sequence_pop_front(Seq::reverse()).reverse();
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
{
return typename sequence_merge<Seqs...>::type{};
}
template <typename F, index_t... Xs>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>)
{
return sequence<f(Xs)...>{};
}
template <typename F, index_t... Xs, index_t... Ys>
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence<Xs...>, sequence<Ys...>)
{
static_assert(sequence<Xs...>::size() == sequence<Ys...>::size(), "Dim not the same");
return sequence<f(Xs, Ys)...>{};
}
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
CK_TILE_HOST_DEVICE constexpr auto
transform_sequences(F f, sequence<Xs...>, sequence<Ys...>, sequence<Zs...>)
{
static_assert(sequence<Xs...>::size() == sequence<Ys...>::size() &&
sequence<Xs...>::size() == sequence<Zs...>::size(),
"Dim not the same");
return sequence<f(Xs, Ys, Zs)...>{};
}
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
return reverse_inclusive_scan_sequence(Seq::pop_front(), Reduce{}, number<Init>{})
.push_back(number<Init>{});
}
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number<Init>{}).reverse();
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, with Init=0 and Reduce=Add
// template parameters: ResultSeq, TargetSeq, Reduce
template <typename, typename, typename>
struct sequence_exclusive_scan;
template <index_t... Xs, index_t Y, index_t... Ys, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y, Ys...>, Reduce>
{
using old_scan = typename sequence_merge<sequence<Xs...>,
sequence<Reduce{}(Y, sequence<Xs...>{}.back())>>::type;
using type = typename sequence_exclusive_scan<old_scan, sequence<Ys...>, Reduce>::type;
};
template <index_t... Xs, index_t Y, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<Y>, Reduce>
{
using type = sequence<Xs...>;
};
template <index_t... Xs, typename Reduce>
struct sequence_exclusive_scan<sequence<Xs...>, sequence<>, Reduce>
{
using type = sequence<Xs...>;
};
template <typename Seq, typename Reduce, index_t Init>
constexpr auto exclusive_scan_sequence(Seq, Reduce, number<Init>)
{
// TODO: c++20 and later can pass in Reduce with a lambda expression
return typename sequence_exclusive_scan<sequence<Init>, Seq, Reduce>::type{};
}
template <typename Seq>
constexpr auto prefix_sum_sequence(Seq)
{
return typename sequence_exclusive_scan<sequence<0>,
typename sequence_merge<Seq, sequence<0>>::type,
plus<index_t>>::type{};
}
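// illustrative example (not part of the original header):
//   prefix_sum_sequence(sequence<2, 3, 4>{})   // sequence<0, 2, 5, 9>
//   (exclusive prefix sum with the total appended as the last element)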
template <typename Seq, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_ids(Seq, sequence<Is...> /* ids */)
{
return sequence<Seq::get(number<Is>{})...>{};
}
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
using new_work_seq = typename std::conditional<RemainMask::front(),
decltype(WorkSeq::push_back(RemainSeq::front())),
WorkSeq>::type;
using type =
typename pick_sequence_elements_by_mask_impl<new_work_seq,
decltype(RemainSeq::pop_front()),
decltype(RemainMask::pop_front())>::type;
};
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, sequence<>, sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Mask>
CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
static_assert(Seq::size() == Mask::size(), "wrong!");
return typename detail::pick_sequence_elements_by_mask_impl<sequence<>, Seq, Mask>::type{};
}
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
using new_work_seq = decltype(WorkSeq::modify(RemainIds::front(), RemainValues::front()));
using type =
typename modify_sequence_elements_by_ids_impl<new_work_seq,
decltype(RemainValues::pop_front()),
decltype(RemainIds::pop_front())>::type;
};
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, sequence<>, sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Values, typename Ids>
CK_TILE_HOST_DEVICE constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
static_assert(Values::size() == Ids::size() && Seq::size() >= Values::size(), "wrong!");
return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr index_t
reduce_on_sequence(Seq, Reduce f, number<Init> /*initial_value*/)
{
index_t result = Init;
for(index_t i = 0; i < Seq::size(); ++i)
{
result = f(result, Seq::at(i));
}
return result;
}
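// illustrative example (not part of the original header):
//   reduce_on_sequence(sequence<1, 2, 3>{}, plus<index_t>{}, number<10>{})   // 16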
// TODO: a generic any_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f)
{
bool flag = false;
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag || f(Seq::at(i));
}
return flag;
}
// TODO: a generic all_of for any container
template <typename Seq, typename F>
CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f)
{
bool flag = true;
for(index_t i = 0; i < Seq::size(); ++i)
{
flag = flag && f(Seq::at(i));
}
return flag;
}
template <typename... Seqs>
using sequence_merge_t = typename sequence_merge<Seqs...>::type;
template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_sequence(number<Is>...)
{
return sequence<Is...>{};
}
// F() returns index_t
// F is default-constructed internally, so F cannot be a lambda
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number<N>)
{
return typename sequence_gen<N, F>::type{};
}
// F() returns number<>
// F can be a lambda
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number<N>)
{
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
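// illustrative usage (a sketch, not part of the original header):
//   generate_sequence_v2([](auto i) { return number<2 * i.value>{}; }, number<3>{})
//   // sequence<0, 2, 4>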
template <class... T>
struct tuple;
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
{
return sequence<Is...>{};
}
namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;
template <index_t h_idx, index_t x, index_t... xs, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x, xs...>, sequence<r, rs...>>
{
template <typename Histogram>
constexpr auto operator()(Histogram& h)
{
if constexpr(x < r)
{
h.template at<h_idx>() += 1;
sorted_sequence_histogram<h_idx, sequence<xs...>, sequence<r, rs...>>{}(h);
}
else
{
h.template at<h_idx + 1>() = 1;
sorted_sequence_histogram<h_idx + 1, sequence<xs...>, sequence<rs...>>{}(h);
}
}
};
template <index_t h_idx, index_t x, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, sequence<x>, sequence<r, rs...>>
{
template <typename Histogram>
constexpr auto operator()(Histogram& h)
{
if constexpr(x < r)
{
h.template at<h_idx>() += 1;
}
}
};
} // namespace detail
template <typename, index_t>
struct array; // forward declaration for later use (array -> sequence utility)
// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
template <typename SeqSortedSamples, index_t r, index_t... rs>
CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence<r, rs...>)
{
constexpr auto bins = sizeof...(rs); // or categories
constexpr auto histogram = [&]() {
array<index_t, bins> h{0}; // make sure this clears all elements to zero
detail::sorted_sequence_histogram<0, SeqSortedSamples, sequence<rs...>>{}(h);
return h;
}();
return TO_SEQUENCE(histogram, bins);
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
{
using T = remove_cvref_t<decltype(f(number<0>{}))>;
return unpack([&f](auto&&... is) { return array<T, N>{f(is)...}; },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck_tile {
// implements the C++20 std::span: a lightweight, non-owning reference to a contiguous
// sequence, whether the range is dynamic or static. It can also be seen as a view of a
// contiguous sequence.
// TODO: do we need this on the device, considering it only holds a raw pointer?
template <typename T>
class span
{
public:
using element_type = T;
using value_type = std::remove_cv_t<element_type>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = element_type*;
using const_pointer = const element_type*;
using reference = element_type&;
using const_reference = const element_type&;
using iterator = pointer;
using const_iterator = pointer;
CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {}
CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count)
{
}
CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
{
}
template <std::size_t N>
CK_TILE_HOST_DEVICE constexpr span(std::array<value_type, N>& arr) noexcept
: span(arr.data(), N)
{
}
template <typename Container>
CK_TILE_HOST_DEVICE constexpr span(const Container& container)
: span(container.data(), container.size())
{
}
CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); }
CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); }
CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); }
CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); }
CK_TILE_HOST_DEVICE constexpr reference back() const { return *(end() - 1); }
CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const
{
return *(begin() + idx);
}
CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; }
CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; }
private:
pointer ptr_;
size_type size_;
};
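// illustrative usage (a sketch, not part of the original header):
//   float buf[4] = {0.f, 1.f, 2.f, 3.f};
//   ck_tile::span<float> s(buf, 4);   // non-owning view over the 4 elements
//   float sum = 0.f;
//   for(auto v : s) sum += v;         // iterates via the raw pointer range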
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
template <typename T, index_t N>
using statically_indexed_array = tuple_array<T, N>;
#else
// consider marking this alias as deprecated
template <typename T, index_t N>
using statically_indexed_array = array<T, N>;
#endif
// consider always using ck_tile::array for this purpose
#if 0
template <typename X, typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return statically_indexed_array<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty statically_indexed_array
template <typename X>
CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array()
{
return statically_indexed_array<X, 0>();
}
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/tuple.hpp"
namespace ck_tile {
#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
template <typename T, index_t N>
using thread_buffer = tuple_array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_tuple(ts...);
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_array(ts...);
}
#endif
// clang-format off
template<typename T_, index_t N_>
struct thread_buffer {
using value_type = remove_cvref_t<T_>;
static constexpr index_t N = N_;
value_type data[N];
// TODO: this ctor cannot be omitted
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
CK_TILE_HOST_DEVICE auto & get() {return data; }
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: kept for compatibility
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
static_assert(N % kSPerX == 0);
union {
thread_buffer<X_, N / kSPerX> data {};
// tuple_array<value_type, kSPerX> sub_data;
value_type sub_data[N];
} vx;
static_for<0, N, 1>{}(
[&](auto j) { vx.sub_data[j] = data[j]; });
return vx.data;
}
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data {};
tuple_array<value_type, kSPerX> sub_data;
} vx;
static_for<0, kSPerX, 1>{}(
[&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
return vx.data;
}
#if 0
template <typename X_,
index_t Is,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
{
using X = remove_cvref_t<X_>;
constexpr index_t kSPerX = vector_traits<X>::vector_size;
union {
X_ data;
tuple_array<value_type, kSPerX> sub_data;
} vx {x};
static_for<0, kSPerX, 1>{}(
[&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
}
#endif
#define TB_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
template<typename Tx>
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
template<typename Tx>
CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(); // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
template<typename Tx, index_t I>
CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
if constexpr(sizeof(value_type) <= 1 )
return _get_as<Tx>(number<I>{}); // TODO: the current compiler needs a union to read 8-bit data back; should be fixed in the future
else
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef TB_COMMON_AS
};
// clang-format on
template <typename>
struct vector_traits;
// specialization for thread_buffer
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <utility>
#include <initializer_list>
#ifndef CK_TILE_TUPLE_IMPL
#define CK_TILE_TUPLE_IMPL 1
#endif
namespace ck_tile {
namespace impl {
template <typename T, index_t N>
struct tuple_array_impl;
}
template <typename T, index_t N>
using tuple_array = typename impl::tuple_array_impl<T, N>::type;
namespace impl {
// the place where content is stored
template <index_t idx, typename T, bool is_empty = std::is_empty_v<T>>
struct tuple_object
{
};
template <index_t idx, typename T>
struct tuple_object<idx, T, true>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U&)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&&)
{
}
#endif
};
template <index_t idx, typename T>
struct tuple_object<idx, T, false>
{
CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {}
#if CK_TILE_TUPLE_IMPL == 0
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e)
{
}
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e)
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <typename U,
typename std::enable_if<!std::is_same<remove_cvref_t<U>, tuple_object>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward<U>(e))
{
}
#endif
T element;
};
// NOTE: we return an instance (not a reference) if the content is empty
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object<I, T, true>&)
{
return {};
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object<I, T, false>& x)
{
return x.element;
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object<I, T, false>& x)
{
return x.element;
}
template <index_t I, class T>
CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object<I, T, false>&& x)
{
return static_cast<T&&>(x.element);
}
template <typename index_seq, typename... T>
struct tuple_base;
template <index_t... I, typename... T>
struct tuple_base<sequence<I...>, T...> : tuple_object<I, T>...
{
CK_TILE_HOST_DEVICE constexpr tuple_base() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
#define _ILE() (std::initializer_list<U>{}.size() - 1)
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list<U> us)
: tuple_object<I, T>(static_cast<T>(*(us.begin() + (I >= _ILE() ? _ILE() : I))))...
{
}
#undef _ILE
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u)
: tuple_object<I, T>(std::forward<U>(u))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object<I, T>(u)...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>&& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<const tuple_object<I, U>&>(u)))...
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base<sequence<I...>, U...>& u)
: tuple_object<I, T>(getv(static_cast<tuple_object<I, U>&>(u)))...
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <class U,
typename std::enable_if<sizeof...(I) == 1 && sizeof...(T) == 1 &&
!std::is_same<remove_cvref_t<U>, tuple_base>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object<I, T>(std::forward<U>(u))...
{
}
template <typename... U, typename std::enable_if<sizeof...(U) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object<I, T>(std::forward<U>(u))...
{
static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U),
"wrong! inconsistent size");
}
#endif
};
} // namespace impl
template <class... T>
struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
{
CK_TILE_HOST_DEVICE
static constexpr auto size() { return sizeof...(T); }
using base = impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>;
CK_TILE_HOST_DEVICE constexpr tuple() = default;
#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST
template <typename U>
CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list<U> us) : base(us)
{
}
#endif
#if CK_TILE_TUPLE_IMPL == 0
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...)
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>&& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(const tuple<U...>& u)
: base(static_cast<const impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
template <class... U>
CK_TILE_HOST_DEVICE constexpr tuple(tuple<U...>& u)
: base(static_cast<impl::tuple_base<make_index_sequence<sizeof...(U)>, U...>&>(u))
{
}
#elif CK_TILE_TUPLE_IMPL == 1
template <
typename U,
typename std::enable_if<sizeof...(T) == 1 && !std::is_same<remove_cvref_t<U>, tuple>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward<U>(u))
{
}
template <typename... U,
typename std::enable_if<sizeof...(U) == sizeof...(T) && sizeof...(U) >= 2,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward<U>(u)...)
{
}
#endif
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool flag = true;
static_for<0, sizeof...(T), 1>{}([&flag](auto i) {
flag &= is_static_v<remove_cvref_t<__type_pack_element<i.value, T...>>>;
});
return flag;
}
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number<I>) { TP_COM_(); return get<I>(); } // TODO: kept for compatibility
// the functions below should be used with the tuple_array<> type; no extra check is performed here
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast<tuple_array<Tx, size()>&>(*this); }
template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast<const tuple_array<Tx, size()>&>(*this); }
// the index below refers to the index *AFTER* the type conversion, not before
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i); }
//template <typename Tx> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(i); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) { TP_COM_(); return reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number<I>) const { TP_COM_(); return reinterpret_cast<const tuple_array<Tx, size()>&>(*this).at(number<I>{}); }
// template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x) { TP_COM_(); reinterpret_cast<tuple_array<Tx, size()>&>(*this).at(number<I>{}) = x; }
// clang-format on
#undef TP_COM_
};
template <typename>
struct vector_traits;
// specialization for tuple
template <typename... T>
struct vector_traits<tuple<T...>>
{
using scalar_type = __type_pack_element<0, T...>;
static constexpr index_t vector_size = sizeof...(T);
};
// template <class... T>
// CK_TILE_HOST_DEVICE constexpr
// tuple<T...>
// make_tuple(T const&... t)
// {
// return {t...};
// }
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
bool same = true;
static_for<0, sizeof...(Xs), 1>{}([&](auto i) {
if(a[i] != b[i])
{
same = false;
}
});
return same;
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple<Xs...>& a, const tuple<Xs...>& b)
{
return !(a == b);
}
template <typename... Xs>
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs)
{
// here xs is always an lvalue as a function argument
// Xs may be deduced as follows (e.g. when passing an integer):
// 1). passing an rvalue (like a function return or int{}) -> Xs is "int"
// 2). passing a const lvalue -> Xs is "const int &"
// 3). passing a non-const lvalue -> Xs is "int &"
// so the return type of std::forward depends on Xs:
// 1). std::forward -> int&&
// 2). std::forward -> const int&
// 3). std::forward -> int&
return tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
}
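// illustrative deduction example (a sketch, not part of the original header):
//   int a = 1;
//   const int b = 2;
//   auto t = make_tuple(a, b, 3);   // remove_cvref_t makes this tuple<int, int, int>
//   t.get<0>() = 10;                // elements are stored by value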
// https://en.cppreference.com/w/cpp/utility/tuple/tie
template <typename... Args>
constexpr tuple<Args&...> tie(Args&... args) noexcept
{
return {args...};
}
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<tuple<Xs...>, tuple<Ys...>>
{
using type = tuple<Xs..., Ys...>;
};
namespace impl {
// be very careful using this type (because we want the internal type)
// template deduction will fail when inferring the inner type
// e.g.
// template<typename T, index_t N> using some_wrapper = typename tuple_array_impl<T, N>::type;
// template<typename T, index_t N> void foo(const some_wrapper<T, N>&) {}
// -> the compiler will fail to deduce this type, because it is a non-deduced context
// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced
// contexts")
//
// -> use this instead
// template<typename Tup> void foo(const Tup&) {}
template <typename T, index_t N>
struct tuple_array_impl
{
using type = typename tuple_concat<typename tuple_array_impl<T, N / 2>::type,
typename tuple_array_impl<T, N - N / 2>::type>::type;
};
template <typename T>
struct tuple_array_impl<T, 0>
{
using type = tuple<>;
};
template <typename T>
struct tuple_array_impl<T, 1>
{
using type = tuple<T>;
};
} // namespace impl
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return make_tuple(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <typename F, index_t N>
CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number<N>)
{
return unpack([&f](auto&&... is) { return tie(f(is)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple<X&...>& tx,
const tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
template <typename... X, typename... Y>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Supports concatenating any number of tuples (including a single one)
template <typename... X>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx)
{
return tx;
}
template <typename... X, typename... Tuples>
CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple<X...>& tx, const Tuples&... tuples)
{
return concat_tuple(tx, concat_tuple(tuples...));
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
{
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default, unroll until the nested tuple is fully flattened
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
{
return t;
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t)
{
return make_tuple(t);
}
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<Ts...>& t)
{
if constexpr(Depth == MaxDepth)
{
return t;
}
else
{
return unpack(
[&](auto&&... ts) {
return concat_tuple(unroll_nested_tuple<Depth + 1, MaxDepth>(ts)...);
},
t);
}
}
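// illustrative usage (a sketch, not part of the original header):
//   auto nested = make_tuple(make_tuple(1, 2), 3);
//   auto flat   = unroll_nested_tuple(nested);   // tuple<int, int, int>{1, 2, 3}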
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
{
return generate_tuple(
[&](auto i) {
using Idx = number<tuple<Ts...>::size() - i - 1>;
return t.at(Idx{});
},
number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in specific range using Function
template <index_t Idx, index_t End, typename F, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple<Ts...>& t)
{
static_assert(Idx < End, "Wrong parameters for tuple_reduce");
if constexpr(Idx + 1 == End)
{
return t.at(number<Idx>{});
}
else
{
return f(t.at(number<Idx>{}), tuple_reduce<Idx + 1, End>(f, t));
}
}
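// illustrative usage (a sketch, not part of the original header):
//   auto t   = make_tuple(1, 2, 3, 4);
//   auto sum = tuple_reduce<0, 4>([](auto a, auto b) { return a + b; }, t);   // 10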
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
template <index_t depth = 0, typename T>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&)
{
return depth;
}
template <index_t depth = 0, typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple<Ts...>&)
{
return max(tuple_depth<depth + 1>(Ts{})...);
}
template <typename... Seqs>
CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple<Seqs...> t_of_s)
{
constexpr index_t n0 = sizeof...(Seqs);
constexpr index_t max_n1 = [&] {
index_t max_n1_ = 0;
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size();
max_n1_ = max_n1_ < n1 ? n1 : max_n1_;
});
return max_n1_;
}();
array<array<index_t, max_n1>, n0> a_of_a{{-1}};
static_for<0, n0, 1>{}([&](auto i0) {
constexpr index_t n1 = t_of_s[i0].size();
static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; });
});
return a_of_a;
}
// We should use MultiIndex<NSize> here instead of tuple<Ys...>, although the former
// is an alias of the latter. This is because the compiler cannot infer NSize when
// MultiIndex<NSize> is used.
// TODO: how to fix this?
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
}
template <typename... Ys,
typename X,
std::enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; });
return r;
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; });
return r;
}
// MultiIndex = scalar * MultiIndex
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; });
return r;
}
// MultiIndex = MultiIndex * scalar
template <
typename... Xs,
typename Y,
std::enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
{
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] / y[i]; }, number<NSize>{});
}
} // namespace ck_tile
#include <tuple>
// WARNING: needed by the compiler for C++ structured binding support only; do not use this directly
namespace std {
template <typename... Ts>
struct tuple_size<ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, ck_tile::tuple<Ts...>> : std::tuple_element<I, std::tuple<Ts...>>
{
};
template <typename... Ts>
struct tuple_size<const ck_tile::tuple<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)>
{
};
template <std::size_t I, typename... Ts>
struct tuple_element<I, const ck_tile::tuple<Ts...>>
: std::tuple_element<I, const std::tuple<Ts...>>
{
};
} // namespace std
#if 1
#define TO_TUPLE_OF_NUMBER(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::tuple<ck_tile::number<a[ck_tile::number<IDX_IDX_>{}]>...>{}; \
} \
(ck_tile::make_index_sequence<n>{}) _Pragma("clang diagnostic pop")
#else
#define TO_TUPLE_OF_NUMBER(arr, n_) \
[&arr, n_] { \
static_assert(arr.size() >= n_, "wrong! out of bound"); \
\
static_assert(n_ < 7, "not implemented"); \
\
if constexpr(n_ == 0) \
{ \
return ck_tile::tuple<>{}; \
} \
else if constexpr(n_ == 1) \
{ \
return ck_tile::tuple<number<arr[0]>>{}; \
} \
else if constexpr(n_ == 2) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>>{}; \
} \
else if constexpr(n_ == 3) \
{ \
return ck_tile::tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>>{}; \
} \
else if constexpr(n_ == 4) \
{ \
return ck_tile:: \
tuple<number<arr[0]>, number<arr[1]>, number<arr[2]>, number<arr[3]>>{}; \
} \
else if constexpr(n_ == 5) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>>{}; \
} \
else if constexpr(n_ == 6) \
{ \
return ck_tile::tuple<number<arr[0]>, \
number<arr[1]>, \
number<arr[2]>, \
number<arr[3]>, \
number<arr[4]>, \
number<arr[5]>>{}; \
} \
}()
#endif
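// Illustrative note (not part of the original header): TO_TUPLE_OF_NUMBER(a, n) turns the
// first n elements of a constexpr array-like "a" of index_t into a ck_tile::tuple of
// number<> values, e.g. an array holding {1, 2, 3} becomes
// ck_tile::tuple<number<1>, number<2>, number<3>>.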
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#pragma once
namespace ck_tile {
enum class bf16_rounding_mode
{
standard = 0, // rtn
truncate_with_nan,
truncate,
};
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses __hip_bfloat16 as a struct
struct alignas(2) bfloat16_t
{
using raw_type = uint16_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr bfloat16_t bit_cast(raw_type x)
{
bfloat16_t y;
y.data = x;
return y;
}
// constructor
constexpr bfloat16_t() : data() {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr bfloat16_t(const unsigned int& x)
: data(float_to_bf16_raw(static_cast<float>(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf16_to_float_raw(data); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return bf16_to_double_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;
template <>
struct native_t<bfloat16_t>
{
using type = ushort;
};
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
// round to nearest
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
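// Worked example of the round-to-nearest-even path above (illustrative values only):
//   f = 0x3F808000 (1.00390625): the low 16 bits are exactly 0x8000 and the bf16 LSB is 0,
//       so only 0x7fff is added and the result truncates to 0x3F80 (tie kept at the even value).
//   f = 0x3F818000 (1.01171875): the bf16 LSB is 1, so 0x7fff + 1 is added and the result
//       becomes 0x3F82 (tie rounded up to the even value).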
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else
return float_to_bf16_truc_raw(f);
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float_raw(uint16_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double_raw(uint16_t x)
{
return static_cast<double>(bf16_to_float_raw(x));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(bf16_to_float(x)); }
template <class T>
struct numeric;
template <>
struct numeric<bfloat16_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
}
// maximum rounding error
// bin : f edcba 9876543210
// bits: s eeeeeeee mmmmmmm
// 0 01111110 0000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
}
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x)
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
uint16_t xx = bit_cast<bf16_raw_t>(x);
return (xx & 0x7FFF) > 0x7F80; // exponent all 1s and mantissa non-zero (bf16 layout)
}
CK_TILE_DEVICE
bfloat16_t sqrt(bfloat16_t x)
{
return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// fp8 rounding modes
// use standard for rounding to nearest, the faster one
// use stochastic for stochastic rounding, helps to avoid error accumulation
enum class fp8_rounding_mode
{
standard = 0,
stochastic
};
/*
* ______________NANOO_________________ | ______________IEEE________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111          s.00000.11         | s.0000.111         s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
* Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
* 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
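// Illustrative usage sketch (example only, mirrors the bf16 API above): the rounding mode is a
// compile-time argument, e.g.
//   uint8_t a = float_to_fp8_raw(0.5f);                                            // default mode
//   uint8_t b = float_to_fp8_raw(0.5f, constant<fp8_rounding_mode::stochastic>{});  // SR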
#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
{
static constexpr int exponent = 4;
static constexpr int mantissa = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr float8_e4m3_t bit_cast(raw_type x)
{
float8_e4m3_t y;
y.data = x;
return y;
}
// constructor
constexpr float8_e4m3_t() : data() {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
{
}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e4m3_t(const unsigned int& x)
: data(float_to_fp8_raw(static_cast<float>(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp8_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
using fp8_t = float8_e4m3_t;
using fp8_raw_t = typename fp8_t::raw_type;
struct alignas(1) float8_e5m2_t
{
static constexpr int exponent = 5;
static constexpr int mantissa = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr float8_e5m2_t bit_cast(raw_type x)
{
float8_e5m2_t y;
y.data = x;
return y;
}
// constructor
constexpr float8_e5m2_t() : data() {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
{
}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr float8_e5m2_t(const unsigned int& x)
: data(float_to_bf8_raw(static_cast<float>(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return bf8_to_float_raw(data); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
using bf8_t = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;
template <typename>
struct native_t;
template <>
struct native_t<fp8_t>
{
using type = _BitInt(8);
};
template <>
struct native_t<bf8_t>
{
using type = unsigned _BitInt(8);
};
#else
using fp8_t = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
// below is the software fp8 conversion path, which does not use hardware instructions
namespace impl {
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
// fp8/bf8 exponent/mantissa layout
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
// original type exponent/mantissa layout
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
#if CK_TILE_USE_CUSTOM_DATA_TYPE
constexpr Y nan_code =
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
#else
constexpr Y nan_code = 0x80;
#endif
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;
// convert to bitwise
using T_bitwise = typename numeric_traits<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
head = x_bitwise & numeric_traits<X>::head_mask;
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = numeric_traits<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
if constexpr(negative_zero_nan)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// check if x is 0.0
if(x_bitwise == 0)
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
// First check whether the input is normal or denormal, since that decides whether there is an
// implicit 1. Then adjust the exponent to align with the f8 exponent, shifting the mantissa
// accordingly. For stochastic rounding, add rng to the mantissa and truncate; for RNE, no rng is
// added. Finally, check whether a carry occurred and adjust the exponent and mantissa again.
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormals are below 2^-127, so they are usually not a concern here; we are mostly
concerned with fp16. In that case the f8 result is usually denormal as well, but there are
exceptions: fp16 has exponent bias 15 while bf8 with NANOO has exponent bias 16, so some fp16
denormals are actually bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
with exponent==0 (actual exponent -14) and the highest mantissa bit set are bf8 (NANOO) normal.
In that case the fp16 mantissa should be shifted left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
For example in fp8 NANOO mode the denormal exponent is -7, but if the fp32/fp16
actual exponent is -7 the value is actually larger due to the implicit 1, so it
needs to be adjusted to -6 and the mantissa shifted right by 1. Thus, for fp32/fp16,
exponent -8 is the cut point for converting to an fp8 NANOO denormal */
exponent_diff = out_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether the value is a tie has to be made before
the right shift, because shifting right can strip off residual bits and make a value that is not
a midpoint look like one. For example, the fp16 number 0x1002 (0 00100 0000000010) is larger than
the midpoint, but after shifting right by 4 bits it would look like a midpoint. */
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant);
// if there is no implicit 1, the f8 value is denormal and needs the denormal exponent adjustment
out_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
bool odd =
mantissa &
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if(out_exponent == 0)
{
if((1 << in_mant) & mantissa)
{
out_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
}
}
else
{
if((1 << (in_mant + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
// No need to make 1 implicit now as it will be addressed later
}
}
mantissa >>= (in_mant - out_mant);
if(out_exponent > max_exp)
{
if(clip)
{
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
}
else
{
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return __builtin_bit_cast(
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
mantissa &= (1 << out_mant) - 1;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
}
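// Illustrative trace of the cast above for x = 1.0f, Y = fp8 e4m3, negative_zero_nan = true
// (a sketch, not an exhaustive description):
//   x_bitwise = 0x3F800000 -> sign = 0, exponent = 127, mantissa = 0
//   act_exponent = 0, exponent_diff = 0, out_bias = 8, out_exponent = 0 + 8 = 8
//   mantissa >>= (in_mant - out_mant) leaves only the implicit 1, which is masked off at the end
//   result = (0 << 7) | (8 << 3) | 0 = 0x40, which decodes back to 2^(8 - 8) = 1.0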
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
// fp8/bf8 exponent/mantissa layout
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_traits<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = numeric_traits<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_traits<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = numeric_traits<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x_raw == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x_raw >> (in_exp + in_mant);
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
int exponent = (x_raw & 0x7F) >> in_mant;
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
T_bitwise retval;
if constexpr(negative_zero_nan)
{
if(x_raw == nan_code)
return NaN;
}
else
{
if(x_raw == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
{
retval = x_raw;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - in_mant);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= out_mant - in_mant;
// subnormal output (occurs when Y = half_t, in_exp = 5, negative_zero_nan = true)
if(exponent <= 0)
{
mantissa |= 1 << out_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
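// Corresponding decode sketch for x_raw = 0x40 (fp8 e4m3, negative_zero_nan = true, Y = float):
//   sign = 0, exponent = 8, mantissa = 0, exp_low_cutoff = 128 - 8 + 1 - 1 = 120
//   exponent += exp_low_cutoff - 1  ->  8 + 119 = 127, mantissa <<= (23 - 3)  ->  0
//   retval = 0x3F800000, i.e. 1.0f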
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
{
// check datatype
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}
} // namespace impl
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
// clang-format off
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
else return fp8_raw_t{0};
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return bf8_raw_t{0};
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif
}
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
{
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
{
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
// clang-format on
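// Example round trip through the value-typed wrappers above (illustrative only; 0.5f is exactly
// representable in e4m3, so the round trip is lossless):
//   fp8_t q = float_to_fp8(0.5f);   // uses CK_TILE_FLOAT_TO_FP8_DEFAULT
//   float r = fp8_to_float(q);      // r == 0.5f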
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<fp8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 8;
#else
static constexpr int bias = 7;
#endif
};
template <>
struct numeric_traits<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 16;
#else
static constexpr int bias = 15; // IEEE
#endif
};
template <class T>
struct numeric;
template <>
struct numeric<fp8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr fp8_t min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t max()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
}
// maximum rounding error
// bin : 7 6543 210
// bits: s eeee mmm
// 0 0110 000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
}
};
template <>
struct numeric<bf8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bf8_t min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t max()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
}
// maximum rounding error
// bin : 7 65432 10
// bits: s eeeee mm
// 0 01110 00 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
}
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif
// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x)
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
}
CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
uint8_t xx = bit_cast<fp8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}
CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
uint8_t xx = bit_cast<bf8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}
CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {
using fp16_hip_t = _Float16; // most HIP internal functions use this type
using fp16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses fp16_hip_t as the interchangeable data type for float16
struct alignas(2) half_t
{
using raw_type = fp16_raw_t;
raw_type data;
CK_TILE_HOST_DEVICE
static constexpr half_t bit_cast(raw_type x)
{
half_t y;
y.data = x;
return y;
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// constructor
constexpr half_t() : data{} {}
// construct from HIP half
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast<raw_type>(x)) {}
// construct from float
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {}
// construct from double
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {}
// construct from int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const int& x) : half_t(static_cast<fp16_hip_t>(__int2half_rn(x))) {}
// construct from unsigned int
CK_TILE_HOST_DEVICE
explicit constexpr half_t(const unsigned int& x)
: half_t(static_cast<fp16_hip_t>(__uint2half_rn(x)))
{
}
// cast to float
CK_TILE_HOST_DEVICE
explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); }
// cast to double
CK_TILE_HOST_DEVICE
explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); }
// cast to int
CK_TILE_HOST_DEVICE
explicit constexpr operator int() const
{
return static_cast<int>(fp16_to_float_hip(to_fp16()));
}
CK_TILE_HOST_DEVICE
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;
template <>
struct native_t<half_t>
{
using type = _Float16;
};
using fp16_t = half_t;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t = _Float16;
using half_t = _Float16;
using fp16_raw_t = ushort;
#endif
// conversions
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
// return __half2float(x);
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
return static_cast<double>(fp16_to_float_hip(x));
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
return __float2half(x);
// return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr double fp16_to_double(const half_t& x) { return static_cast<double>(x); }
CK_TILE_HOST_DEVICE
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
CK_TILE_HOST_DEVICE
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
// limits
template <class T>
struct numeric;
template <>
struct numeric<half_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr half_t min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr half_t max()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
}
// maximum rounding error
// bin : f edcba 9876543210
// bits: s eeeee mmmmmmmmmm
// 0 01110 0000000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr half_t zero()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
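// Descriptive note: these traits are what impl::run_cast_to_f8 / run_cast_from_f8 consume when
// the source or destination type is half_t: head_mask selects sign + exponent, mant_mask the
// 10 mantissa bits, and nan_mask the all-ones exponent field.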
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
{
return __heq(x.to_fp16(), y.to_fp16());
}
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); }
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
return half_t(__hadd(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); }
CK_TILE_DEVICE
half_t operator-(const half_t& x, const half_t& y)
{
return half_t(__hsub(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator*(const half_t& x, const half_t& y)
{
return half_t(__hmul(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t operator/(const half_t& x, const half_t& y)
{
return half_t(__hdiv(x.to_fp16(), y.to_fp16()));
}
CK_TILE_DEVICE
half_t& operator+=(half_t& x, const half_t& y)
{
x = half_t(__hadd(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator-=(half_t& x, const half_t& y)
{
x = half_t(__hsub(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator*=(half_t& x, const half_t& y)
{
x = half_t(__hmul(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator/=(half_t& x, const half_t& y)
{
x = half_t(__hdiv(x.to_fp16(), y.to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator++(half_t& x)
{
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t& operator--(half_t& x)
{
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return x;
}
CK_TILE_DEVICE
half_t operator++(half_t& x, int)
{
half_t y(x);
x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
CK_TILE_DEVICE
half_t operator--(half_t& x, int)
{
half_t y(x);
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x)
{
    return bit_cast<half_t>(static_cast<fp16_raw_t>(x.get() & 0x7fff));
}
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
{
uint16_t xx = x.get();
return (xx & 0x7FFF) > 0x7C00;
}
CK_TILE_DEVICE
half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace ck_tile {
using index_t = int32_t;
using long_index_t = int64_t;
using int8_t = int8_t;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
template <auto v>
struct constant
{
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
template <typename T, T v>
struct integral_constant : constant<v>
{
using value_type = T;
using type = integral_constant; // using injected-class-name
static constexpr T value = v;
// constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; }
// constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } //
};
template <index_t v>
using number = constant<v>;
template <long_index_t v>
using long_number = constant<v>;
template <bool b>
using bool_constant = constant<b>;
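// Illustrative usage sketch (example only): number<> values participate in compile-time
// arithmetic through the operator overloads generated below, e.g.
//   constexpr auto n = number<3>{} + number<4>{};  // n is constant<7>
//   static_assert(n == 7);                         // uses the implicit conversion to index_t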
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>) \
{ \
return constant<(OP x)>{}; \
}
#define CK_TILE_BINARY_OP(OP) \
template <auto x, auto y> \
CK_TILE_HOST_DEVICE constexpr auto operator OP(constant<x>, constant<y>) \
{ \
return constant<(x OP y)>{}; \
}
CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP(-)
CK_TILE_LEFT_UNARY_OP(~)
CK_TILE_LEFT_UNARY_OP(!)
CK_TILE_LEFT_UNARY_OP(*)
CK_TILE_BINARY_OP(+)
CK_TILE_BINARY_OP(-)
CK_TILE_BINARY_OP(*)
CK_TILE_BINARY_OP(/)
CK_TILE_BINARY_OP(%)
CK_TILE_BINARY_OP(&)
CK_TILE_BINARY_OP(|)
CK_TILE_BINARY_OP(^)
CK_TILE_BINARY_OP(<<)
CK_TILE_BINARY_OP(>>)
CK_TILE_BINARY_OP(&&)
CK_TILE_BINARY_OP(||)
CK_TILE_BINARY_OP(==)
CK_TILE_BINARY_OP(!=)
CK_TILE_BINARY_OP(>)
CK_TILE_BINARY_OP(<)
CK_TILE_BINARY_OP(>=)
CK_TILE_BINARY_OP(<=)
#undef CK_TILE_LEFT_UNARY_OP
#undef CK_TILE_BINARY_OP
} // namespace ck_tile