Unverified commit db376dd8 authored by carlushuang, committed by GitHub

introducing ck_tile! (#1216)

* enable gfx940

* switch between intrinsic mfma routines on mi100/200 and mi300

* fix mfma_int8 on MI300

* disable 2 int8 examples on MI300

* Update cmake-ck-dev.sh

* restore gitignore file

* modify Jenkinsfile to the internal repo

* Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx

Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core
  dependency-type: direct:production
  update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>

* initial enablement of gfx950

* fix clang format

* disable examples 31 and 41 int8 on gfx950

* add code

* fix build wip

* fix xx

* now can build

* naming

* minor fix

* wip fix

* fix macro for exp2; fix warpgemm a/b in transposedC

* unify as tuple_array

* Update the required Python version to 3.9

* Update executable name in test scripts

* re-structure tuple/array to avoid spill

* Merge function templates

* Fix format

* Add constraint to array<> ctor

* Re-use function

* Some minor changes

* remove wrong code in store_raw()

* fix compile issue in transpose

* Rename enum
Rename 'cood_transform_enum' to 'coord_transform_enum'

* let more integral_constant->constant, and formating

* make sure thread_buffer can be tuple/array

* temp fix buffer_store spill

* do not use custom data types by default; now we can generate ISA-level identical code to opt_padding

* fix compile error, fp8 not ready now

* fix fp8 duplicated move/shift/and/or problem

* Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode

* fix scratch in fp8 kernel

* update some readme

* fix merge from upstream

* sync with upstream

* sync upstream again

* sync 22

* remove unused

* fix clang-format

* update README of ck_tile example

* fix several issues

* set the minimum Python version to 3.8

* remove ck_tile example from default cmake target like all/install/check

* remove mistake

* 1) support recipes in generate.py 2) use simplified mask type 3) pass left/right into karg

* fix some bugs in group-mode masking and codegen; update README

* F8 quantization for FMHA forward (#1224)

* Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline

* Add element function to fmha api

* Adjust P elementwise function

* Fix bug in elementwise op; our elementwise op is not in-out

* Add some elementwise op, prepare to quantization

* Let generate.py generate different elementwise functions

* To prevent a compiler issue, remove the elementwise functions we have not used.

* Remove f8 pipeline, we should share the same pipeline even in f8

* Remove remove_cvref_t

* Avoid warning

* Fix wrong fp8 QK/KV block gemm setting

* Check fp8 rounding error in check_err()

* Set fp8 rounding error for check_err()

* Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode

* 1. codegen the f8 api and kernel
2. f8 host code

* prevent warning in filter mode

* Remove not-in-use elementwise function kargs

* Remove more not-in-use elementwise function kargs

* Small refinements in C++ source files

* Use conditional_t<> to simplify code

* Support heterogeneous argument for binary function types

* Re-use already-existing scales<> functor template

* Fix wrong value produced by saturating

* Generalize the composes<> template

* Unify saturates<> implementation

* Fix type errors in composes<>

* Extend less_equal<>

* Reuse the existing template less_equal<> in check_err()

* Add equal<float> & equal<double>

* Rename check_err() parameter

* Rename check_err() parameter

* Add FIXME comment for adding new macro in future

* Remove unnecessary cast to void

* Eliminate duplicated code

* Avoid dividing api pool into more than 2 groups

* Use clearer variable names

* Use affirmative condition in if stmt

* Remove blank lines

* Do not use perfect forwarding in composes<>

* To fix compile error, revert generate.py back to 4439cc107dd90302d68a6494bdd33113318709f8

* Fix bug of p element function

* Add compute element op to host softmax

* Remove element function in api interface

* Extract user parameter

* Rename pscale and oscale variable

* rename f8 to fp8

* rename more f8 to fp8

* Add pipeline::operator() without element_functor

* 1. Remove deprecated pipeline enum
2. Refine host code parameter

* Use quantization range as input

* 1. Rename max_dtype to dtype_max.
2. Rename scale to scale_s.
3. Add init description.

* Refine description

* prevent early return

* unify _squant kernel name in cpp, update README

* Adjust the default range.

* Refine error message and bias range

* Add fp8 benchmark and smoke test

* fix fp8 swizzle_factor=4 case

---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

---------
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include <type_traits>
#include <stdint.h>
#include <cmath>
namespace ck_tile {
template <typename Scale, Scale lhs>
struct scales_c
{
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
template <typename Scale>
struct scales
{
static_assert(std::is_copy_constructible_v<Scale>);
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
-> decltype(std::declval<const Scale&>() * rhs)
{
return lhs_ * rhs;
}
private:
Scale lhs_;
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale)->scales<Scale>;
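// Usage sketch (illustrative, not part of the original header):
//   ck_tile::scales s{2.0f};                    // runtime scale factor, CTAD -> scales<float>
//   float y = s(3.0f);                          // 6.0f
//   auto  z = ck_tile::scales_c<int, 4>{}(5);   // compile-time scale factor, 20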
template <typename Left = void, typename Right = Left>
struct plus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
{
return lhs + rhs;
}
};
template <>
struct plus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs + rhs)
{
return lhs + rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus()->plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
template <>
struct minus<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs - rhs)
{
return lhs - rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus()->minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
template <>
struct multiplies<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies()->multiplies<void, void>;
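// Usage sketch (illustrative): the primary templates fix the operand types up front, while the
// <void, void> specializations deduce them per call, similar to std::plus<>:
//   ck_tile::plus<float>{}(1.0f, 2.0f);   // 3.0f
//   ck_tile::multiplies{}(2, 3.5);        // heterogeneous operands, yields 7.0 (double)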
template <typename T>
struct maximize
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
template <typename T>
struct minimize
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};
template <typename T>
struct integer_divide_ceiler
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const
{
static_assert(std::is_same<T, index_t>{} || std::is_same<T, int>{}, "wrong type");
return (a + b - number<1>{}) / b;
}
};
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - number<1>{}) / y;
}
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
{
return y * integer_divide_ceil(x, y);
}
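// Usage sketch (illustrative):
//   ck_tile::integer_divide_floor(10, 4);    // 2
//   ck_tile::integer_divide_ceil(10, 4);     // 3
//   ck_tile::integer_least_multiple(10, 4);  // 12, smallest multiple of 4 that is >= 10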
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <>
CK_TILE_DEVICE constexpr float max(float x, float y)
{
return __builtin_fmaxf(x, y); // can result in v_max3_f32
}
template <>
CK_TILE_DEVICE constexpr double max(double x, double y)
{
return __builtin_fmax(x, y); // maybe still v_max3_f32
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
return X > y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number<Y>)
{
return x > Y ? x : Y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return max(x, max(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
return x;
}
template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <>
CK_TILE_DEVICE constexpr float min(float x, float y)
{
return __builtin_fminf(x, y);
}
template <>
CK_TILE_DEVICE constexpr double min(double x, double y)
{
return __builtin_fmin(x, y);
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{
return X < y ? X : y;
}
template <index_t Y>
CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number<Y>)
{
return x < Y ? x : Y;
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return min(x, min(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
return min(max(x, lowerbound), upperbound);
}
CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); }
CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); }
// greatest common divisor, aka highest common factor
CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y)
{
if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{
return y;
}
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
}
template <index_t X, index_t Y>
CK_TILE_HOST_DEVICE constexpr auto gcd(number<X>, number<Y>)
{
constexpr auto r = gcd(X, Y);
return number<r>{};
}
template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys)
{
return gcd(x, gcd(ys...));
}
// least common multiple
template <typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
{
return (x * y) / gcd(x, y);
}
template <typename X,
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys)
{
return lcm(x, lcm(ys...));
}
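// Usage sketch (illustrative):
//   ck_tile::gcd(12, 18);                                        // 6
//   ck_tile::lcm(4, 6, 10);                                      // 60
//   ck_tile::gcd(ck_tile::number<8>{}, ck_tile::number<12>{});   // number<4>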
template <typename Left = void, typename Right = Left>
struct equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
template <>
struct equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs == rhs)
{
return lhs == rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal()->equal<void, void>;
template <>
struct equal<float, float>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
{
return bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
}
};
template <>
struct equal<double, double>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
{
return bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
}
};
template <typename Left = void, typename Right = Left>
struct less
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
template <>
struct less<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs < rhs)
{
return lhs < rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less()->less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
{
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
template <>
struct less_equal<void, void>
{
template <typename Left, typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const
-> decltype(lhs <= rhs)
{
return lhs <= rhs;
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal()->less_equal<void, void>;
template <>
struct less_equal<float, float>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const
{
return lhs < rhs || bit_cast<uint32_t>(lhs) == bit_cast<uint32_t>(rhs);
}
};
template <>
struct less_equal<double, double>
{
CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const
{
return lhs < rhs || bit_cast<uint64_t>(lhs) == bit_cast<uint64_t>(rhs);
}
};
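// Note (illustrative): the float/double specializations above compare bit patterns via bit_cast
// instead of operator==, so e.g. +0.0f and -0.0f compare unequal while two bit-identical NaNs
// compare equal; check_err() relies on this exact-match behavior:
//   ck_tile::equal<float>{}(0.0f, -0.0f);      // false (0x00000000 vs 0x80000000)
//   ck_tile::less_equal<float>{}(1.0f, 1.0f);  // true  (bit-identical)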
CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x)
{
// TODO: x needs to be in 2 ~ 0x7fffffff; 0, 1, or values larger than 0x7fffffff will fail to compile
return 1 << (32 - clz(x - 1));
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two()
{
constexpr index_t y = next_power_of_two(X);
return number<y>{};
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number<X>)
{
constexpr index_t y = next_power_of_two(X);
return number<y>{};
}
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
{
// TODO: x needs to be in 1 ~ 0x7fffffff
// __builtin_clz produces an unexpected result if x is 0
return 31 - __builtin_clz(x);
}
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
{
// TODO: x needs to be in 1 ~ 0x7fffffff
return x == (1 << integer_log2_floor(x));
}
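// Usage sketch (illustrative):
//   ck_tile::next_power_of_two(17);         // 32
//   ck_tile::integer_log2_floor(17);        // 4
//   ck_tile::is_power_of_two_integer(48);   // false
//   ck_tile::is_power_of_two_integer(32);   // true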
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
template <typename T>
struct log2e;
template <>
struct log2e<double>
{
static constexpr double value = C_LOG2E;
};
template <>
struct log2e<float>
{
static constexpr float value = C_LOG2E;
};
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
}
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
uint32_t xx = bit_cast<uint32_t>(x);
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
// this struct has the information of
// 1. limits of a certain type, similar to std::numeric_limits
// 2. some pre-defined values, e.g. zero, one...
//
template <typename T>
struct numeric
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr T round_error()
{
return std::numeric_limits<T>::round_error();
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
{
return std::numeric_limits<T>::signaling_NaN();
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
{
return std::numeric_limits<T>::denorm_min();
}
CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }
CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }
#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
CK_TILE_HOST_DEVICE static constexpr T log2e()
{
if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
{
return static_cast<T>(C_LOG2E);
}
else
{
return 0; // TODO: integer?
}
}
};
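// Usage sketch (illustrative):
//   ck_tile::numeric<float>::max();        // same as std::numeric_limits<float>::max()
//   ck_tile::numeric<float>::infinity();   // +inf
//   ck_tile::numeric<float>::log2e();      // ~1.4427f
//   ck_tile::numeric<int>::one();          // 1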
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<float>
{
static constexpr int exp = 8;
static constexpr int mant = 23;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
using bitwise_type = uint32_t;
};
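// Illustrative sketch: decompose a float with numeric_traits<float> (assumes ck_tile::bit_cast
// from ck_tile/core/utility/bit_cast.hpp is available at the point of use):
//   uint32_t bits = ck_tile::bit_cast<uint32_t>(1.0f);                    // 0x3F800000
//   uint32_t mant = bits & ck_tile::numeric_traits<float>::mant_mask;     // 0
//   uint32_t exp  = (bits >> ck_tile::numeric_traits<float>::mant)
//                   & ck_tile::numeric_traits<float>::exp_mask;           // 127 (biased)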
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
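// Usage sketch (illustrative, hypothetical wrapper type; the real fp16/bf16/fp8 headers invoke
// this macro with their own types). The macro only needs a type that is constructible from
// float, convertible to float, and exposes a 'data' member of 'raw_type':
//   struct my_f16
//   {
//       using raw_type = uint16_t;
//       uint16_t data;
//       constexpr my_f16() = default;
//       explicit my_f16(float f);          // convert from float
//       explicit operator float() const;   // convert to float
//   };
//   CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, my_f16)
//   // defines ==, !=, <, <=, >, >=, +, -, *, /, +=, -=, *=, /=, ++, -- by
//   // round-tripping through float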
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include <tuple>
#include <type_traits>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
namespace ck_tile {
#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
return static_cast<Y>(x);
}
#else
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using non_const_y = std::remove_const_t<Y>;
using non_const_x = std::remove_const_t<X>;
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return sname_##_to_##dname_(x); \
}
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
#undef CK_TILE_TYPE_CONVERT
#endif
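// Usage sketch (illustrative; in the default non-custom-data-type build the specializations
// above dispatch to the {fp16,bf16,fp8,bf8}_to_float / float_to_* helpers):
//   float           f = 1.5f;
//   ck_tile::fp16_t h = ck_tile::type_convert<ck_tile::fp16_t>(f);  // float_to_fp16(f)
//   float           g = ck_tile::type_convert<float>(h);            // fp16_to_float(h)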
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// this structure is used to pick up the <base> type inside
// using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allows native types + bool in this position (custom types will fail),
// specialize this structure to supply the proper <base> type
template <typename T>
struct native_t
{
using type = remove_cvref_t<T>;
};
// we purposely name this ext_vector, because the clang ext_vector_type extension only accepts
// literal basic types when constructing an ext_vector_type. You must be very careful using this,
// or you will get a lot of compiler errors, e.g. struct A; using Ax2_t = A
// __attribute__((ext_vector_type(2))); -> compiler error
namespace impl {
template <typename T_, index_t N_>
struct ext_vector
{
static constexpr index_t N = N_;
using value_type = typename native_t<remove_cvref_t<T_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
static constexpr index_t N = Vs_ * N_;
using value_type = typename native_t<remove_cvref_t<V_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};
} // namespace impl
template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;
// by default, any type will result in vector_size=1 with scalar_type=T traits,
// ... unless another vector_traits specialization applies
template <typename T>
struct vector_traits
{
using scalar_type = remove_cvref_t<T>;
static constexpr index_t vector_size = 1;
};
// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
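// Usage sketch (illustrative):
//   using v4 = ck_tile::ext_vector_t<float, 4>;   // float __attribute__((ext_vector_type(4)))
//   using v8 = ck_tile::ext_vector_t<v4, 2>;      // nesting flattens to 8 float lanes
//   static_assert(ck_tile::vector_traits<v4>::vector_size == 4, "");
//   static_assert(std::is_same_v<ck_tile::vector_traits<v4>::scalar_type, float>, "");
//   static_assert(ck_tile::has_same_scalar_type<v4, float>::value, "");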
// below are some pre-defined ext_vector_type aliases
// attention! two vector type aliases could be exactly the same type
// fp64
using fp64_t = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// using fp16_t = ...
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
// using uint16_t
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
// buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default>
struct buffer_view;
// Address Space: generic
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::generic,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum::coherence_default>
{
using type = T;
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
{
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::generic;
}
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
}
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
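// Usage sketch (illustrative; in the wider library these views are normally created through
// higher-level tensor/factory helpers rather than constructed directly):
//   float* p = /* ... */;
//   ck_tile::buffer_view<ck_tile::address_space_enum::generic, float, ck_tile::index_t, true>
//       view{p, 1024};
//   auto v = view.get<ck_tile::fp32x4_t>(/*i=*/0, /*is_valid_element=*/true);  // vector load
//   view.set<ck_tile::fp32x4_t>(/*i=*/0, /*is_valid_element=*/true, v);        // vector store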
// Address Space: Global
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
template <typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
struct buffer_view<address_space_enum::global,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
Coherence>
{
using type = T;
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
{
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::global;
}
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
#if CK_TILE_USE_AMD_BUFFER_LOAD
bool constexpr use_amd_buffer_addressing = true;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check>(
p_data_, i, is_valid_element, buffer_size_);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<
remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check>(
p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
}
}
else
{
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
{
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
dst, p_data_, i, buffer_size_, is_valid_element);
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_);
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
#if CK_TILE_USE_AMD_BUFFER_STORE
bool constexpr use_amd_buffer_addressing = true;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
x, p_data_, i, is_valid_element, buffer_size_);
}
else
{
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
x, p_data_, i, is_valid_element, buffer_size_);
}
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
{
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
}
else
{
if(is_valid_element)
{
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
}
}
}
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
}
else if(is_valid_element)
{
atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
}
}
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::lds,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum::coherence_default>
{
using type = T;
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
{
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::lds;
}
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
return tmp;
#else
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
// using buf_t = ushort __attribute__((ext_vector_type(8)));
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
return bit_cast<X>(rtn);
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
bool constexpr workaround_int8_ds_write_issue = true;
#else
bool constexpr workaround_int8_ds_write_issue = false;
#endif
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
int8_t>::value &&
workaround_int8_ds_write_issue)
{
if(is_valid_element)
{
// HACK: the compiler would lower the IR "store<i8, 16> address_space(3)" into inefficient
// ISA, so we let the compiler emit the IR "store<i32, 4>" instead, which is lowered to
// ds_write_b128
// TODO: remove this after compiler fix
static_assert((std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value),
"wrong! not implemented for this combination, please add "
"implementation");
if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x2_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
std::is_same<remove_cvref_t<X>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
std::is_same<remove_cvref_t<X>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
std::is_same<remove_cvref_t<X>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
}
}
else
{
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
*c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
#endif
}
}
}
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
// transforms of tensor_view/Tensor
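// Illustrative note (editorial): "X contains multiple T" means one access moves a whole X worth of
// T elements. For example, with T = int8_t and X = int8x16_t, get<X>(i, valid) reads the 16
// contiguous int8_t elements starting at element offset i, and set<X>(i, valid, x) writes them
// back; i must therefore be aligned to (a multiple of) 16.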
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::vgpr,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum::coherence_default>
{
using type = T;
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
{
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::vgpr;
}
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
// i is offset of T
// FIXME: doesn't do is_valid check
CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
return X{invalid_element_value_};
}
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
if(is_valid_element)
{
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
}
}
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename T,
typename BufferSizeType>
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
}
template <address_space_enum BufferAddressSpace,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename T,
typename BufferSizeType,
typename X,
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
p, buffer_size, invalid_element_value};
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
{
return tile_window.async_load(lds_tile);
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
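// Illustrative prefetch pattern (editorial sketch; the window names are hypothetical):
//   async_load_tile_raw(k_lds_window, k_gmem_window); // issue global -> LDS loads
//   /* ... independent work while the loads are in flight ... */
//   async_load_fence();                               // s_waitcnt vmcnt(0): wait for them to land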
template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
return null_tensor{};
}
template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
struct null_tensor
{
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
namespace ck_tile {
// placeholder type used when we want to opt out of a tile-window parameter
template <typename WindowLengths_>
struct null_tile_window
{
using BottomTensorView = null_tensor_view;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using BottomTensorIndex = array<index_t, WindowLengths::size()>;
CK_TILE_DEVICE constexpr null_tile_window() = default;
CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
: window_lengths_{window_lengths}
{
}
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
WindowLengths window_lengths_;
};
// utility to check if this is a Null Tile Window
namespace impl {
template <typename>
struct is_null_tile_window : public std::false_type
{
};
template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl
template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}
template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
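// Illustrative usage (editorial sketch; kM/kN are hypothetical compile-time lengths): a kernel can
// accept a null window in place of an optional input and skip the corresponding loads.
//   auto bias_window = make_null_tile_window(make_tuple(number<kM>{}, number<kN>{}));
//   static_assert(is_null_tile_window(bias_window));
//   auto bias = load_tile(bias_window); // returns null_tensor{}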
template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
const WindowLengths& window_lengths,
const multi_index<WindowLengths::size()>& /*origin*/,
Ts&&...)
{
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}
template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {
namespace detail {
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
constexpr auto I0 = number<0>{};
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_major_minor_to_y_({rh_major, rh_minor}) = i;
});
return rh_major_minor_to_y_;
};
constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
}
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
// # of vectors
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// using InVec = typename InVec::type;
// using OutVec = typename OutVec::type;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using SFC_Y = space_filling_curve<decltype(y_lengths),
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Y::get_num_of_access();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_array(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
} // namespace detail
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
using InDataType = typename InTensor::DataType;
using OutDataType = typename OutTensor::DataType;
using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
// type convert
const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
// shuffle
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
InDstrEncode::NDimY == OutDstrEncode::NDimY)
{
detail::shuffle_tile_impl_in_thread(out, in_tmp);
}
else
{
        // NOT implemented: shuffling between tiles with different distribution encodings
}
}
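// Illustrative usage (editorial sketch; the tile/distribution names are hypothetical):
//   auto p_shuffled = make_static_distributed_tensor<PDataType>(p_shuffled_dstr);
//   shuffle_tile(p_shuffled, p); // converts the data type and transposes the per-thread vectors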
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
// NOTE: This API will override the origin of the tile window!
static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());
constexpr auto slice_lengths = slice_ends - slice_begins;
return make_tile_window(tile.get_bottom_tensor_view(),
sequence_to_tuple_of_number(slice_lengths),
to_multi_index(slice_begins));
}
template <typename DataType_,
typename StaticTileDistribution_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using DataType = remove_cvref_t<DataType_>;
using Distribution = remove_cvref_t<StaticTileDistribution_>;
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
sliced_tensor.get_thread_buffer() =
tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);
return sliced_tensor;
}
template <typename DstDataType_,
typename DstStaticTileDistribution_,
typename SrcDataType_,
typename SrcStaticTileDistribution_,
index_t... SliceBegins,
index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
sequence<SliceBegins...> slice_begins,
sequence<SliceEnds...> slice_ends)
{
using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();
static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
    dst_tile.set_y_sliced_thread_data(
        sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}
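// Illustrative usage (editorial sketch; kM/kN0 are hypothetical compile-time extents):
//   auto v_lo = get_slice_tile(v, sequence<0, 0>{}, sequence<kM, kN0>{}); // X-slice [0,kM) x [0,kN0)
//   set_slice_tile(v, v_lo, sequence<0, 0>{}, sequence<kM, kN0>{});       // write the slice back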
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {
template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
using DataType = remove_cvref_t<DataType_>;
using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;
static_assert(StaticTileDistribution::is_static(),
"wrong! StaticTileDistribution should be known at compile tile");
using ThreadTensorDesc =
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
{
return StaticTileDistribution::get_num_of_dimension_x();
}
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
return StaticTileDistribution::get_lengths();
}
CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
{
return StaticTileDistribution{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
return StaticTileDistribution::get_distributed_spans();
}
CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }
CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }
CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
{
return kThreadElementSpaceSize;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>) const
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
});
return sliced_thread_data;
}
template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
sequence<YSliceLengths...>,
const SlicedThreadData& sliced_thread_data)
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
"wrong!");
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};
thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
});
}
template <typename TileDistributedIndices>
CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
}
template <typename TileDistributedIndices>
CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
{
static_assert(is_static_v<TileDistributedIndices>,
"wrong! Tile Distributed Indices should be static");
constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
TileDistributedIndices{});
return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
}
//
thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};
template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
return static_distributed_tensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{};
}
template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
ThreadBuffer&& thread_buffer_)
{
return static_distributed_tensor<remove_cvref_t<DataType>,
remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}
// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
DistributedIndices distributed_indices)
{
const auto partition_index = detail::get_partition_index(tile_distribution);
constexpr auto y_indices =
tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);
const auto x_coord = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(partition_index, to_array<ck_tile::index_t, y_indices.size()>(y_indices)));
return x_coord.get_bottom_index();
}
template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
DataType value,
XIndicesPredicate predicate)
{
constexpr auto out_spans =
static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{},
distributed_indices);
if(predicate(x_indices))
{
out_tensor(distributed_indices) = value;
}
});
});
}
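// Illustrative usage (editorial sketch; acc/AccDataType are hypothetical): zero the strictly
// upper-triangular part of a 2-D tile based on its X (global) indices.
//   set_tile_if(acc, numeric<AccDataType>::zero(), [](auto x_indices) {
//       return x_indices[number<1>{}] > x_indices[number<0>{}];
//   });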
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
using TileDstr = remove_cvref_t<TileDistribution_>;
static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
constexpr auto tile_dstr = TileDstr{};
auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
tile_window_tmp.get_window_lengths(),
tile_window_tmp.get_window_origin(),
tile_dstr);
tile_window.store_raw(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store(dstr_tensor);
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store_raw(dstr_tensor);
}
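// Illustrative usage (editorial sketch; the view/tensor names are hypothetical):
//   auto o_window = make_tile_window(o_view, o_window_lengths, o_window_origin);
//   store_tile(o_window, o_acc); // re-creates the window with o_acc's distribution and writes it out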
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// sweep over a span of a distributed tile and apply the lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
typename F // signature: F(tile_distributed_index<...>)
>
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
using DstrSpan = remove_cvref_t<TileDistributedSpan_>;
static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);
f(dstr_idx);
});
}
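// Illustrative usage (editorial sketch), mirroring set_tile_if above: visit every distributed
// index of a 2-D tile and assign to it (acc/value are hypothetical).
//   constexpr auto spans = decltype(acc)::get_distributed_spans();
//   sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
//       sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
//           acc(make_tuple(idx0, idx1)) = value;
//       });
//   });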
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
namespace ck_tile {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
struct tensor_adaptor
{
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
{
return Transforms::size();
}
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
{
return LowerDimensionHiddenIdss{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
{
return UpperDimensionHiddenIdss{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
{
return BottomDimensionHiddenIds{};
}
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
{
return TopDimensionHiddenIds{};
}
CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);
constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});
constexpr index_t itran = tmp[number<0>{}];
constexpr index_t idim_up = tmp[number<1>{}];
constexpr bool found = tmp[number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length =
transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
return length;
},
number<ndim_top_>{});
// TODO: make container_reduce support tuple of number and index_t
return container_reduce(lengths, multiplies{}, number<1>{});
}
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE static constexpr auto
get_transform_and_its_upper_dimension(number<IDimHidden>)
{
        // FIXME: the length of the bottom dimension is not known, since info about lower-dim
        // lengths is not saved in the transformation
static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == IDimHidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
{
return BottomDimensionHiddenIds::size();
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
{
return TopDimensionHiddenIds::size();
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
less<index_t>,
equal<index_t>>::type;
return unique_sort_all_dim_ids::size();
}
constexpr static index_t ntransform_ = get_num_of_transform();
constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
constexpr static index_t ndim_top_ = get_num_of_top_dimension();
using HiddenIndex = multi_index<ndim_hidden_>;
using BottomIndex = multi_index<ndim_bottom_>;
using TopIndex = multi_index<ndim_top_>;
// may be index_t or number<>
using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{initialize_element_size(transforms)}
{
static_assert(Transforms::size() == ntransform_ &&
LowerDimensionHiddenIdss::size() == ntransform_ &&
UpperDimensionHiddenIdss::size() == ntransform_,
"wrong! inconsistent # of transformations");
        // TODO: check that the dependency between dimensions is valid
}
CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }
    // FIXME: this logic is wrong when getting bottom dimension lengths
template <index_t IDimHidden>
CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
{
static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});
constexpr index_t itran = tmp[number<0>{}];
constexpr index_t idim_up = tmp[number<1>{}];
constexpr bool found = tmp[number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
}
template <index_t IDimTop>
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
{
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
}
#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
template <index_t IDimBottom>
CK_TILE_HOST_DEVICE constexpr index_t
get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
{
        return get_hidden_dimension_length(BottomDimensionHiddenIds::at(idim_bottom));
}
#endif
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
{
return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
number<ndim_top_>{});
}
#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
{
return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
number<ndim_bottom_>{});
}
#endif
template <typename TopIdx>
CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
{
static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = get_num_of_transform();
constexpr index_t ndim_hidden = get_num_of_hidden_dimension();
multi_index<ndim_hidden> idx_hidden;
        // initialize the topmost (top-dimension) index
set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
auto itran = itran_p1 - number<1>{};
const auto& tran = get_transforms().at(itran);
constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
multi_index<dims_low.size()> idx_low;
tran.calculate_lower_index(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
bool is_known = true;
static_for<0, Transforms::size(), 1>{}([&](auto i) {
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
});
return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
{
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
const auto up_guaranteed_vector_lengths =
get_container_subset(guaranteed_vector_lengths, up_dims);
const auto up_guaranteed_vector_strides =
get_container_subset(guaranteed_vector_strides, up_dims);
// only need type of transform
auto [up_vector_lengths, up_vector_strides] =
Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
get_container_subset(vector_lengths, low_dims),
get_container_subset(vector_strides, low_dims));
if constexpr(up_dims.size() > 0)
{
for(index_t i = 0; i < up_dims.size(); ++i)
{
up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
? up_guaranteed_vector_lengths[i]
: up_vector_lengths[i];
up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
? up_guaranteed_vector_strides[i]
: up_vector_strides[i];
}
}
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_adaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::size();
static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
UpperDimensionNewTopIdss::size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};
return tensor_adaptor<remove_cvref_t<Transforms>,
remove_cvref_t<decltype(low_dim_hidden_idss)>,
remove_cvref_t<decltype(up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
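// Illustrative example (editorial sketch; M/N are hypothetical compile-time extents): a
// single-stage adaptor whose bottom is {M, N} and whose top is the merged length {M * N}.
//   constexpr auto adaptor = make_single_stage_tensor_adaptor(
//       make_tuple(make_merge_transform(make_tuple(number<M>{}, number<N>{}))),
//       make_tuple(sequence<0, 1>{}),  // lower (old top) dims consumed by the merge
//       make_tuple(sequence<0>{}));    // upper (new top) dim it produces
//   // adaptor.calculate_bottom_index(multi_index<1>{i}) then yields {i / N, i % N}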
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is placed outside the scope where it is used
// (transform_tensor_adaptor) because a template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
template <typename I>
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
return number<Tran::get_num_of_upper_dimension()>{};
}
};
template <typename OldTensorAdaptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
// sanity check
{
static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldTopIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_top_ids) constexpr {
return transform_sequences(
// convert lower dimension top id to hidden id
[](auto low_dim_top_id) constexpr {
return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
},
low_dim_top_ids);
},
NewLowerDimensionOldTopIdss{});
constexpr index_t num_new_transform = NewTransforms::size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
constexpr auto new_top_dim_hidden_ids =
unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);
// put everything together
const auto all_transforms =
container_concat(old_tensor_adaptor.get_transforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);
return tensor_adaptor<
remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1)
{
static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
TensorAdaptor1::get_num_of_bottom_dimension(),
"wrong!");
// all_transforms = transform0 + transform1
const auto all_transforms =
container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());
// shift
constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id_ =
max(adaptor0_max_hidden_id_,
TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
});
constexpr index_t ndim_up =
TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id_ =
max(adaptor0_max_hidden_id_,
TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
});
});
return adaptor0_max_hidden_id_;
}();
constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();
            // get the min over all lower dimensions, excluding bottom dimensions (because their
            // ids will be matched with top ids from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();
// get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id_ =
min(adaptor1_min_hidden_id_,
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
});
});
return adaptor1_min_hidden_id_;
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();
// all_low_dim_hidden_idss =
    //   low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_low_1 =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();
constexpr auto low_dim_hidden_ids_1 =
TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
// if this low dim is bottom dim, then do id matching
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::get_bottom_dimension_hidden_ids()
[idim_bottom_1])
{
low_dim_hidden_ids_1_mod_(idim_low_1) =
TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
}
});
});
return low_dim_hidden_ids_1_mod_;
}
();
return generate_sequence_v2(
[&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
number<ndim_low_1>{});
},
number<TensorAdaptor1::get_num_of_transform()>{});
constexpr auto all_low_dim_hidden_idss =
container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);
// all_up_dim_hidden_idss =
    //   up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_up_1 =
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();
constexpr auto up_dim_hidden_ids_1 =
TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod_;
}
();
// constexpr tuple to sequence
return generate_sequence_v2(
[&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
number<ndim_up_1>{});
},
number<TensorAdaptor1::get_num_of_transform()>{});
constexpr auto all_up_dim_hidden_idss =
container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};
// put everything together
return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
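// Illustrative note (editorial sketch): chaining composes the coordinate maps bottom-up. With
//   const auto chained = chain_tensor_adaptors(adaptor0, adaptor1);
// adaptor1's bottom dimension ids are re-labelled to adaptor0's top dimension ids, so
//   chained.calculate_bottom_index(idx_top)
// equals
//   adaptor0.calculate_bottom_index(adaptor1.calculate_bottom_index(idx_top));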
} // namespace ck_tile
// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a Tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
\
return make_pass_through_transform(low_len); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
auto left_pad = meta_data.template pop<index_t>(pos); \
auto right_pad = meta_data.template pop<index_t>(pos); \
\
return make_pad_transform(low_len, left_pad, right_pad); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
auto coefficients = \
meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_embed_transform(up_lens, coefficients); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
index_t pos = 0; \
auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
\
return make_merge_transform(low_lens); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_unmerge_transform(up_lens); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_replicate_transform(up_lens); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// Macro function
// construct static tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor is a tuple of the following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a tuple of the following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for the constructor of the transform
// 1.3 num of lower dimensions (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of upper dimensions (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimensions (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimensions (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
return make_pass_through_transform(number<low_len>{}); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
constexpr index_t left_pad = \
meta_data.template get<index_t>(sizeof(low_len)); \
\
constexpr index_t right_pad = \
                            meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad)); \
\
return make_pad_transform( \
number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
constexpr auto coefficients = \
meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
\
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
constexpr auto low_lens = \
meta_data.template get<array<index_t, num_low_dim>>(0); \
\
return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
using HiddenIndex = multi_index<NDimHidden>;
using BottomIndex = multi_index<ndim_bottom_>;
using TopIndex = multi_index<ndim_top_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;
CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
{
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
{
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
}
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }
CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }
//
HiddenIndex idx_hidden_;
};
template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
const TopIndex& idx_top)
{
static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = Adaptor::get_num_of_transform();
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids();
multi_index<ndim_hidden> idx_hidden;
// initialize visible index
set_container_subset(idx_hidden, top_dim_ids, idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - number<1>{};
const auto& tran = adaptor.get_transforms().at(itran);
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
multi_index<dims_low.size()> idx_low;
tran.calculate_lower_index(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return tensor_adaptor_coordinate<ndim_hidden,
remove_cvref_t<decltype(bottom_dim_ids)>,
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
template <bool JudgeDoTransforms = true,
typename Adaptor,
typename AdaptorCoord,
typename TopIndex,
typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top,
BottomIndex& idx_diff_bottom)
{
constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension();
// constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
constexpr index_t ntransform = Adaptor::get_num_of_transform();
// static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");
// judge whether calculation of lower diff is needed for each transform
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
if constexpr(JudgeDoTransforms)
{
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
        // decide do_transforms by checking for non-zero index diff components
multi_index<ndim_top> non_zero_diff_pick_top;
static_for<0, ndim_top, 1>{}(
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
set_container_subset(
is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
multi_index<dims_low.size()> non_zero_diff_pick_low;
            // if any of the upper index diff components is non-zero, then
            // 1) this transform needs to be done
            // 2) all components of the lower index diff are assumed to be non-zero and need to be
            //    computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
}
else
{
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
}
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize top index diff
set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);
// this is what needs to be updated
auto& idx_hidden = coord.get_hidden_index();
// update top index
auto idx_hidden_pick_top =
get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());
idx_hidden_pick_top += idx_diff_top;
set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(do_transforms[itran])
{
const auto& tran = adaptor.get_transforms().at(itran);
constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
multi_index<dims_low.size()> idx_diff_low;
tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
// set bottom index diff
idx_diff_bottom =
get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top)
{
constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
multi_index<ndim_bottom> tmp;
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
const AdaptorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = Adaptor::get_num_of_transform();
const auto& idx_hidden = coord.get_hidden_index();
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
const auto tran = adaptor.get_transforms().at(itran);
        // check validity, only if the current transformation does not always have a valid mapping
if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
{
const auto idx_up = get_container_subset(
idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
}
});
return valid;
}
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
                                                               const AdaptorCoord& coord)
{
// check top index
const auto& idx_top = coord.get_top_index();
bool is_top_index_valid = true;
static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
is_top_index_valid =
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
});
// check other hidden index
return is_top_index_valid &&
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
: public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;
// TODO make these private
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();
using HiddenIndex = multi_index<NDimHidden>;
using TopIndex = multi_index<ndim_top_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
: Base{idx_hidden}
{
}
    // construct from the tensor_adaptor_coordinate base class
CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }
CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
{
return Base::get_bottom_index()[number<0>{}];
}
CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
{
return Base::get_hidden_index();
}
CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};
template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const TopIndex& idx_top)
{
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
adaptor_coord};
}
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
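// Illustrative usage sketch of the coordinate API above (editor's example; the 4x8 packed
// descriptor comes from tensor_descriptor.hpp and make_multi_index() from multi_index.hpp;
// the values are hypothetical):
//   const auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<4>{}, number<8>{}));
//   auto coord      = make_tensor_coordinate(desc, make_multi_index(1, 2)); // offset 1*8 + 2 = 10
//   move_tensor_coordinate(desc, coord, make_multi_index(0, 3));            // offset becomes 13
//   const bool ok   = coordinate_has_valid_offset(desc, coord);             // (1, 5) is in bounds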
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
// TopDimensionHiddenIds : sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths_,
          typename GuaranteedVectorStrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
sequence<0>,
TopDimensionHiddenIds>
{
using Base = tensor_adaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
sequence<0>,
TopDimensionHiddenIds>;
using ElementSpaceSizeType = ElementSpaceSize;
constexpr static index_t ntransform_ = Base::get_num_of_transform();
constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension();
using GuaranteedVectorLengths = GuaranteedVectorLengths_;
    using GuaranteedVectorStrides = GuaranteedVectorStrides_;
static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
GuaranteedVectorStrides::size() == ndim_hidden_,
"wrong! inconsistent # of hidden dimensions");
using TopIndex = multi_index<ndim_top_>;
using HiddenIndex = multi_index<ndim_hidden_>;
public:
CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
: Base{transforms}, element_space_size_{element_space_size}
{
static_assert(Transforms::size() == ntransform_ &&
LowerDimensionHiddenIdss::size() == ntransform_ &&
UpperDimensionHiddenIdss::size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
// construct from tensor_adaptor base class
CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
ElementSpaceSize element_space_size)
: Base{adaptor}, element_space_size_{element_space_size}
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
{
return Base::get_num_of_top_dimension();
}
template <index_t IDim>
CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
{
return Base::get_top_dimension_length(idim);
}
CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
{
return Base::get_top_dimension_lengths();
}
CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
{
return element_space_size_;
}
template <typename Idx>
CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
{
return Base::calculate_bottom_index(idx)[number<0>{}];
}
// TODO make these private
CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
{
return Base::get_transforms();
}
CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
{
return Base::get_lower_dimension_hidden_idss();
}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
{
return Base::get_upper_dimension_hidden_idss();
}
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
{
return Base::get_top_dimension_hidden_ids();
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
return Base::is_known_at_compile_time() &&
ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
{
return Base::get_top_dimension_safe_vector_length_strides(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_descriptor{");
// tensor_adaptor
Base::print();
printf(", ");
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
printf("}");
}
// TODO make these private
ElementSpaceSize element_space_size_;
};
template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
const ElementSpaceSize& element_space_size)
{
constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();
return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
remove_cvref_t<decltype(element_space_size)>,
typename uniform_sequence_gen<NDimHidden, -1>::type,
typename uniform_sequence_gen<NDimHidden, -1>::type>{
adaptor, element_space_size};
}
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
const auto element_space_size = old_tensor_desc.get_element_space_size();
const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
new_transforms,
NewLowerDimensionOldTopIdss{},
NewUpperDimensionNewTopIdss{});
constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();
using NewGuaranteedVectorLengths = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorLengths,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
using NewGuaranteedVectorStrides = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorStrides,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
return tensor_descriptor<
remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
remove_cvref_t<decltype(element_space_size)>,
NewGuaranteedVectorLengths,
NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
namespace detail {
template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
if constexpr(i.value < Lengths::size() - 1)
{
return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
}
else
{
return acc_new;
}
}
} // namespace detail
/*
 * These functions create naive tensor descriptors
*/
// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
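// e.g. (editor's illustration) a 4 x 8 row-major tensor with compile-time lengths and strides:
//   constexpr auto desc = make_naive_tensor_descriptor(make_tuple(number<4>{}, number<8>{}),
//                                                      make_tuple(number<8>{}, number<1>{}));
//   // desc.get_lengths() == (4, 8), desc.get_element_space_size() == 32,
//   // desc.calculate_offset(make_tuple(1, 2)) == 10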
template <typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size =
detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorStride>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
// tensor descriptor with an offset: the offset is not added to the element space size;
// it only records the starting offset and is taken into account in offset calculation
template <typename... Lengths,
typename... Strides,
typename offset,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
const offset& os,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = detail::calculate_element_space_size_impl(
lengths, strides, number<0>{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, os));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
constexpr auto visible_dim_hidden_ids = sequence<1>{};
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorStride>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}();
constexpr index_t N = sizeof...(Lengths);
return transform_tensor_descriptor(
desc_0,
make_tuple(make_embed_transform(lengths, strides)),
make_tuple(sequence<0>{}),
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) long_number<>
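// e.g. (editor's illustration) a packed (contiguous, row-major) 2 x 64 tensor; the strides are
// implied as (64, 1) and element_space_size is 128:
//   constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<2>{}, number<64>{}));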
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_unmerge_transform(lengths));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
template <typename... Lengths,
          typename Offset,
          index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
const tuple<Lengths...>& lengths,
const Offset& offset,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});
constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});
constexpr auto visible_dim_hidden_ids = sequence<1>{};
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;
return tensor_descriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}();
constexpr index_t N = sizeof...(Lengths);
return transform_tensor_descriptor(
desc_0,
make_tuple(make_unmerge_transform(lengths)),
make_tuple(sequence<0>{}),
make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) number<>, which is known at compile-time
// align could be:
// 1) index_t, or
// 2) number<>
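// e.g. (editor's illustration) lengths (4, 6) with align = 8: the last dimension keeps stride 1
// and the next-to-last stride is rounded up to a multiple of the alignment, giving strides (8, 1),
// i.e. each row is padded from 6 to 8 elements:
//   const auto desc = make_naive_tensor_descriptor_aligned(
//       make_tuple(number<4>{}, number<6>{}), number<8>{});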
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
constexpr auto I1 = number<1>{};
constexpr index_t N = sizeof...(Lengths);
const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);
auto strides = generate_tuple(
[&](auto i) {
if constexpr(i.value == N - 1)
{
return I1;
}
else if constexpr(i.value == N - 2)
{
return number<stride_n_minus_2>{};
}
else
{
return container_reduce(
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
}
},
number<N>{});
return make_naive_tensor_descriptor(lengths, strides);
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
using buffer_view = remove_reference_t<BufferView_>;
using DataType = typename buffer_view::type;
using TensorDesc = remove_cvref_t<TensorDesc_>;
using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
CK_TILE_HOST_DEVICE constexpr tensor_view() = default;
CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
const TensorDesc& desc)
: buf_{buffer_view}, desc_{desc}
{
}
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
{
return TensorDesc::get_num_of_top_dimension();
}
CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
#if 0
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
{
return buf_.template get<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
{
buf_.template set<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
#endif
    // X is a vector of DataType.
    // "coord" is the coordinate of DataType, not of X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
    // X is a vector of DataType.
    // "coord" is the coordinate of DataType, not of X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
template <typename X,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
}
    // X is a vector of DataType.
    // "coord" is the coordinate of DataType, not of X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
// buf_
printf("buf_: ");
print(buf_);
printf(", ");
// desc_
printf("desc_: ");
print(desc_);
printf("}");
}
// member
buffer_view buf_;
TensorDesc desc_;
};
// placeholder type if we want to opt out of a tensor view parameter
struct null_tensor_view
{
};
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
const tensor_descriptor<Ts...>& desc)
{
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
auto desc = make_naive_tensor_descriptor(lengths,
strides,
number<GuaranteedLastDimensionVectorLength>{},
number<GuaranteedLastDimensionVectorStride>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
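// Illustrative sketch (editor's example; p_a, M and N are hypothetical, and the last-dim
// vector length/stride hints are optional):
//   auto a_view = make_naive_tensor_view<address_space_enum::global>(
//       p_a,                        // DataType* pointing into global memory
//       make_tuple(M, N),           // run-time lengths
//       make_tuple(N, number<1>{}), // row-major strides
//       number<8>{},                // guaranteed last-dim vector length
//       number<1>{});               // guaranteed last-dim vector stride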
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename DataType,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
auto desc =
make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});
auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}
template <typename OldTensorView,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss>
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
{
auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
new_transforms,
NewLowerDimensionOldVisibleIdss{},
NewUpperDimensionNewVisibleIdss{});
return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
old_tensor_view.buf_, new_desc};
}
template <typename TensorView,
typename TileLengths, // tuple<...>
typename DoPads> // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
constexpr index_t num_dim = DoPads::size();
static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
"wrong! inconsistent # of dimensions");
// transforms
const auto transforms = generate_tuple(
[&](auto idim) {
const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);
const auto tile_length = tile_lengths[idim];
const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;
const auto pad_length = new_length - old_length;
constexpr bool DoPad = DoPads::at(idim);
const auto transform =
conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
make_pass_through_transform(old_length));
return transform;
},
number<num_dim>{});
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});
// upper dimension Id
const auto upper_dimss = lower_dimss;
return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}
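// Illustrative sketch (editor's example; a_view and the 128 x 64 tile shape are hypothetical):
// pad both dimensions of a view up to multiples of the tile lengths, so that whole tiles can be
// accessed with out-of-bounds handling provided by the pad transforms:
//   auto a_view_padded = pad_tensor_view(a_view,
//                                        make_tuple(number<128>{}, number<64>{}),
//                                        sequence<true, true>{});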
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// distributed span
template <index_t... PartialHsLengths>
struct tile_distributed_span
{
using Impl = sequence<PartialHsLengths...>;
static constexpr auto impl_ = Impl{};
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
// distributed index
template <index_t... PartialHsIndices>
struct tile_distributed_index
{
using Impl = sequence<PartialHsIndices...>;
static constexpr auto impl_ = Impl{};
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
};
namespace detail {
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence<Is...>)
{
return tile_distributed_span<Is...>{};
}
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence<Is...>)
{
return tile_distributed_index<Is...>{};
}
} // namespace detail
template <typename PsYs2XsAdaptor_,
typename Ys2DDescriptor_,
typename StaticTileDistributionEncoding_,
          typename TileDistributionDetail_> // FIXME: this holds ad-hoc but useful info,
                                            // should be more elegant
struct tile_distribution
{
using PsYs2XsAdaptor = remove_cvref_t<PsYs2XsAdaptor_>;
using Ys2DDescriptor = remove_cvref_t<Ys2DDescriptor_>;
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
using DstrDetail = remove_cvref_t<TileDistributionDetail_>;
static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(),
"wrong! should be static");
static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
PsYs2XsAdaptor ps_ys_to_xs_;
Ys2DDescriptor ys_to_d_;
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{
#if 0
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
ps_ys_to_xs_.GetBottomDimensionLengths();
#else
return generate_tuple(
[&](auto i) {
constexpr index_t x_length =
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
return number<x_length>{};
},
number<NDimX>{});
#endif
}
CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
{
return ps_ys_to_xs_;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding()
{
return DstrEncode{};
}
#if 1
    // Calculate the replication index [R0, R1, ...] based on the partition index
// FIXME: very nasty implementation
template <typename PartitionIndex>
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
{
static_assert(PartitionIndex::size() == NDimP, "wrong!");
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
array<index_t, NDimR> rs_idx;
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
static_for<0, ndim_low, 1>{}([&](auto i) {
constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
// 0-th rh_major is the replicate dimension
if constexpr(rh_major == 0)
{
constexpr index_t adaptor_hidden_id =
DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
// fill in
rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
}
});
});
return rs_idx;
}
#endif
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
return generate_tuple(
[&](auto i) {
constexpr auto span_impl = distributed_spans_impl[i];
constexpr index_t ndim_span_minor = ndims_spans_minor[i];
constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
return detail::make_tile_distributed_span(span);
},
number<NDimX>{});
}
// FIXME: it's hacky to get Y index from Distributed-Index
template <typename DistributedIndices>
CK_TILE_HOST_DEVICE static constexpr auto
get_y_indices_from_distributed_indices(DistributedIndices)
{
constexpr auto ys_idx_arr = [] {
array<index_t, NDimY> ys_idx;
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
ys_idx(i) = dstr_index.impl_[span_minor];
});
return ys_idx;
}();
constexpr index_t ndim_y = NDimY;
return TO_SEQUENCE(ys_idx_arr, ndim_y);
}
CK_TILE_HOST_DEVICE static constexpr bool is_static()
{
return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution{");
//
printf("tile_distribution_encoding: ");
print(DstrEncode{});
printf(", ");
//
printf("ps_ys_to_xs_: ");
print(ps_ys_to_xs_);
printf(", ");
//
printf("ys_to_d_: ");
print(ys_to_d_);
//
printf("}");
}
};
namespace detail {
template <index_t NDimMax>
CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend)
{
array<index_t, NDimMax> arr{0};
for(index_t i = 0; i < iend - ibegin; ++i)
{
arr(i) = ibegin + i;
}
return arr;
}
// this returns a constexpr encoding of tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto
make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
{
using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
// FIXME: increase max value if fail
constexpr index_t kMaxNumTransforms = 20;
constexpr index_t kMaxMetaDataSize = 128;
constexpr index_t kMaxNumDim = 10;
using Name = coord_transform_enum;
using MetaData = meta_data_buffer<kMaxMetaDataSize>;
using NumDim = index_t;
using Dims = array<index_t, kMaxNumDim>;
using Lengths = array<index_t, kMaxNumDim>;
// Tile Adaptor
// bottom dims [x0, x1, x2, ...]
// top dims [p0, p1, ..., y0, y1, ...]
constexpr index_t ndim_x = HsLengthss::size();
// Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
index_t num_tran = 0;
index_t hidden_dim_cnt = ndim_x;
// this is replicate transform
{
constexpr index_t ndim_r_minor = RsLengths::size();
constexpr auto r_minor_lengths = RsLengths{};
trans(num_tran++) = {
coord_transform_enum::replicate,
MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
NumDim{0},
Dims{},
NumDim{ndim_r_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
for(index_t i = 0; i < ndim_r_minor; ++i)
{
rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
hidden_dim_cnt++;
}
};
    // these are unmerge transforms for X dimensions
static_for<0, ndim_x, 1>{}([&trans,
&num_tran,
&hidden_dim_cnt,
&rh_major_minor_to_hidden_ids,
&rh_major_minor_to_hidden_lengths](auto idim_x) {
constexpr auto h_minor_lengths =
HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr index_t ndim_h_minor = h_minor_lengths.size();
trans(num_tran++) = {
coord_transform_enum::unmerge,
MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
NumDim{1},
Dims{idim_x},
NumDim{ndim_h_minor},
make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
for(index_t i = 0; i < ndim_h_minor; ++i)
{
rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
hidden_dim_cnt++;
}
});
// transform: P dimensions
constexpr index_t ndim_p = Ps2RHssMajor::size();
Dims hidden_dim_id_ps;
static_for<0, ndim_p, 1>{}([&](auto iDimP) {
//
index_t hidden_dim_id_p = hidden_dim_cnt++;
hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
constexpr index_t ndim_low = p2RHsMajor.size();
Dims low_dims;
Lengths low_lengths;
for(index_t i = 0; i < ndim_low; ++i)
{
index_t rh_major = p2RHsMajor[i];
index_t rh_minor = p2RHsMinor[i];
low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
}
trans(num_tran++) = {coord_transform_enum::merge,
MetaData{to_array<index_t, ndim_low>(low_lengths)},
NumDim{ndim_low},
low_dims,
NumDim{1},
Dims{hidden_dim_id_p}};
});
constexpr index_t ndim_bottom = ndim_x;
constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
constexpr index_t ndim_y = Ys2RHsMajor::size();
constexpr index_t ndim_top = ndim_p + ndim_y;
auto top_dim_ids = hidden_dim_id_ps;
{
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
}
}
//
const auto ps_ys_to_xs_adaptor_encoding =
make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
// descriptor: [y0, y1, ...] to [d]
Lengths y_lengths;
index_t d_length = 1;
for(index_t i = 0; i < ndim_y; ++i)
{
index_t rh_major = ys_to_rhs_major[i];
index_t rh_minor = ys_to_rhs_minor[i];
index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
y_lengths(i) = y_length;
d_length *= y_length;
}
auto tran = make_tuple(coord_transform_enum::unmerge,
MetaData{to_array<index_t, ndim_y>(y_lengths)},
NumDim{1},
Dims{0},
NumDim{ndim_y},
make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
const auto ys_to_d_adaptor_encoding = make_tuple(
make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
return make_tuple(ps_ys_to_xs_adaptor_encoding,
ys_to_d_adaptor_encoding,
d_length,
rh_major_minor_to_hidden_ids);
}
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
struct tile_distribution_detail
{
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ =
to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
};
} // namespace detail
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
constexpr index_t d_length = adaptor_impl.template at<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
//
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return tile_distribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
constexpr auto adaptor_impl =
detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
constexpr index_t d_length = adaptor_impl.template at<2>();
constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
constexpr auto ps_ys_to_xs_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
constexpr auto ys_to_d_adaptor =
CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
constexpr auto ys_to_d_descriptor =
make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});
//
constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
constexpr auto rh_major_minor_to_hidden_ids =
TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
return tile_distribution<
remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
remove_cvref_t<decltype(ys_to_d_descriptor)>,
remove_cvref_t<DstrEncode>,
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
//***********************************************************************************
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
// the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the size per slice
// output: the per-dim lengths of each slice, and the number of slices per dim
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// returns tuple<slice_lengths, slice_nums, slice_index>, where slice_index is the index at which
// the slices start to split (scanning right -> left),
// i.e. the first index whose sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
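// Illustrative calls (a sketch; the values follow the table above):
//   constexpr auto s = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
//   // s == tuple<sequence<1, 1, 8>{}, sequence<4, 2, 1>{}, number<1>{}>
// With a mask, masked-out dims (mask == 0) keep their full length:
//   constexpr auto sm = reverse_slice_sequence(
//       sequence<4, 2, 1, 4, 2>{}, number<4>{}, sequence<1, 1, 1, 0, 1>{});
//   // sm == tuple<sequence<1, 2, 1, 4, 2>{}, sequence<4, 1, 1, 1, 1>{}, number<0>{}>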
//
// slice a tensor along an x_dim; the split happens along y_dims, not p_dims.
// We don't support slicing across a p_dim (i.e. slicing across different threads).
// Also, the y_dim being sliced needs to be the leading dim of the current x_dim's decomposition;
// having another Y dim multiplied in before the sliced dim does not make sense.
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim; it is the first dim of X1, 4 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim; the P dim to its left is 1, so this is OK
// 16 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slicing along this P dim would split threads, which is not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to be split into 2
// sub-dims
// the P dim to the left is 1, which means the slice does not actually cross P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
// NOTE: this function needs to be called in a constexpr context;
// due to https://wg21.link/p2280r0 we have to use a non-reference type for the distribution
using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
constexpr auto src_y_dims = src_y_info[number<0>{}];
constexpr auto src_y_maps = src_y_info[number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::detail::ys_lengths_;
// This lambda modifies values captured from the enclosing scope, so C++ will not treat
// its return value as constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
constexpr auto sliced_h_lens = sliced_h[number<0>{}];
constexpr auto sliced_h_index = sliced_h[number<2>{}];
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
// TODO: add validation that the slice does not cross a p dim
// NOTE: this y_origin covers all dims, not only the current dim;
// y_picks is used below to select the target dims
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
});
return y_origin_;
}();
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
src_y_prefix_sum[id + 1],
1>::type{};
set_container_subset(
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
return sliced_h_lens;
},
typename Encoding::HsLengthss{},
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
typename Encoding::Ys2RHsMinor>{}),
sliced_y_origins,
sliced_y_lengths);
}
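// Illustrative usage (a sketch): with a distribution 'Dstr' (a placeholder name) whose
// encoding matches the first example above (X0 = <1, 4, 32>, X1 = <4, 1, 4, 2, 4>),
// take the sub-tile [0, 8) along X1 while keeping all of X0 (0 length means "all"):
//   constexpr auto sliced = slice_distribution_from_x(
//       Dstr{}, sequence<0, 0>{}, sequence<0, 8>{});
//   // sliced == tuple<sliced_static_distribution, sliced_y_origins, sliced_y_lengths>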
} // namespace detail
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename RsLengths_, // sequence<...>
typename HsLengthss_, // tuple<sequence<...>, ...>
typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
typename Ys2RHsMajor_, // sequence<...>
typename Ys2RHsMinor_> // sequence<...>
struct tile_distribution_encoding
{
using RsLengths = remove_cvref_t<RsLengths_>;
using HsLengthss = remove_cvref_t<HsLengthss_>;
using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
using Ys2RHsMajor = remove_cvref_t<Ys2RHsMajor_>;
using Ys2RHsMinor = remove_cvref_t<Ys2RHsMinor_>;
static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
static constexpr index_t NDimX = HsLengthss::size();
static constexpr index_t NDimP = Ps2RHssMajor::size();
static constexpr index_t NDimY = Ys2RHsMajor::size();
static constexpr index_t NDimR = RsLengths::size();
// FIXME: move into detail
static constexpr auto rs_lengths_ = RsLengths{};
static constexpr auto hs_lengthss_ = HsLengthss{};
static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
// redundant but useful info
// TODO: really bad code, should be overhauled
struct detail
{
// ndim_rh_major_, ndim_span_major_
static constexpr index_t ndim_rh_major_ = NDimX + 1;
static constexpr index_t ndim_span_major_ = NDimX;
// ndims_rhs_minor_[ndim_rh_major_]
static constexpr auto ndims_rhs_minor_ = generate_array(
[](auto i) {
if constexpr(i.value == 0)
{
return rs_lengths_.size();
}
else
{
return hs_lengthss_[i - number<1>{}].size();
}
},
number<ndim_rh_major_>{});
// max_ndim_rh_minor_
static constexpr index_t max_ndim_rh_minor_ =
container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_lengthss_ =
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
// ys_lengths_
static constexpr auto ys_lengths_ = [] {
array<index_t, NDimY> ys_lengths_tmp{-1};
for(index_t i = 0; i < NDimY; i++)
{
index_t rh_major = ys_to_rhs_major_[i];
index_t rh_minor = ys_to_rhs_minor_[i];
ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
}
return ys_lengths_tmp;
}();
// rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_ys_ = [] {
array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
constexpr index_t rh_major = ys_to_rhs_major_[i];
constexpr index_t rh_minor = ys_to_rhs_minor_[i];
rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
});
return rhs_major_minor_to_ys_tmp;
}();
// ndims_span_minor_[NDimY]
static constexpr auto ndims_span_minor_ = [] {
array<index_t, NDimX> ndims_span_minor{0};
for(index_t i = 0; i < NDimY; i++)
{
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_span_minor(span_major)++;
}
return ndims_span_minor;
}();
// max_ndim_span_minor_
static constexpr index_t max_ndim_span_minor_ =
container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_span_minor_ = [] {
array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
{-1}};
static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
index_t cnt_ndim_span_minor = 0;
static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
if(idim_y >= 0)
{
rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
cnt_ndim_span_minor++;
}
});
});
return rhs_major_minor_to_span_minor;
}();
// ys_to_span_major_[NDimY]
static constexpr auto ys_to_span_major_ =
generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
// ys_to_span_minor_[NDimY]
static constexpr auto ys_to_span_minor_ = generate_array(
[](auto i) {
return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
},
number<NDimY>{});
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
static constexpr auto distributed_spans_lengthss_ = [] {
array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
distributed_spans_lengthss{{-1}};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t rh_major = ys_to_rhs_major_[i];
const index_t rh_minor = ys_to_rhs_minor_[i];
const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
const index_t span_major = rh_major - 1;
const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
distributed_spans_lengthss(span_major)(span_minor) = h_length;
});
return distributed_spans_lengthss;
}();
// ndims_distributed_spans_minor_[ndim_span_major_]
static constexpr auto ndims_distributed_spans_minor_ = [] {
array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
static_for<0, NDimY, 1>{}([&](auto i) {
const index_t span_major = ys_to_rhs_major_[i] - 1;
ndims_distributed_spans_minor(span_major)++;
});
return ndims_distributed_spans_minor;
}();
// does_p_own_r_[NDimP][NDimR]
static constexpr auto does_p_own_r_ = [] {
if constexpr(NDimR > 0)
{
array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
if constexpr(rh_major == 0)
{
does_p_own_r(idim_p)(rh_minor) = true;
}
});
});
return does_p_own_r;
}
else
{
return array<array<bool, NDimR>, NDimP>{};
}
}();
// ps_over_rs_derivative_[NDimP][NDimR]
static constexpr auto ps_over_rs_derivative_ = [] {
if constexpr(NDimR > 0)
{
array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
static_for<0, NDimP, 1>{}([&](auto idim_p) {
constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
index_t p_over_rh_derivative = 1;
static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
if constexpr(rh_major == 0)
{
ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
}
p_over_rh_derivative *= rh_length;
});
});
return ps_over_rs_derivative;
}
else
{
return array<array<index_t, NDimR>, NDimP>{};
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
[&](auto i) {
constexpr index_t size = HsLengthss{}[i].size();
return number<size>{};
},
number<NDimX>{});
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
return h_dim_prefix_sum;
}
CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
{
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr auto x_dim_prefix_sum = merge_sequences(
sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
return x_dim_prefix_sum.at(major) + minor;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
return all_ys_2_rhss;
}
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template <typename IdxSeq, typename PrefixSumSeq>
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
{
using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
constexpr auto sorted_dims = typename sorted_idx::type{};
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
constexpr auto sorted_histogram =
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
}
CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
{
return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding::detail{");
//
printf("ndim_rh_major_: ");
print(ndim_rh_major_);
printf(", ");
//
printf("ndim_span_major_: ");
print(ndim_span_major_);
printf(", ");
//
printf("ndims_rhs_minor_: ");
print(ndims_rhs_minor_);
printf(", ");
//
printf("max_ndim_rh_minor_: ");
print(max_ndim_rh_minor_);
printf(", ");
//
printf("rhs_lengthss_: ");
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
//
printf("ndims_span_minor_: ");
print(ndims_span_minor_);
printf(", ");
//
printf("max_ndim_span_minor_: ");
print(max_ndim_span_minor_);
printf(", ");
//
printf("ys_to_span_major_: ");
print(ys_to_span_major_);
printf(", ");
//
printf("ys_to_span_minor_: ");
print(ys_to_span_minor_);
printf(", ");
//
printf("distributed_spans_lengthss_: ");
print(distributed_spans_lengthss_);
printf(", ");
//
printf("ndims_distributed_spans_minor_: ");
print(ndims_distributed_spans_minor_);
printf(", ");
//
printf("ps_over_rs_derivative_: ");
print(ps_over_rs_derivative_);
//
printf("}");
}
};
CK_TILE_HOST_DEVICE void print() const
{
printf("tile_distribution_encoding{");
//
printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
//
printf("rs_lengths_: ");
print(rs_lengths_);
printf(", ");
//
printf("hs_lengthss_: ");
print(hs_lengthss_);
printf(", ");
//
printf("ps_to_rhss_major_: ");
print(ps_to_rhss_major_);
printf(", ");
//
printf("ps_to_rhss_minor_: ");
print(ps_to_rhss_minor_);
printf(", ");
//
printf("ys_to_rhs_major_: ");
print(ys_to_rhs_major_);
printf(", ");
//
printf("ys_to_rhs_minor_: ");
print(ys_to_rhs_minor_);
printf(", ");
//
printf("detail: ");
print(detail{});
//
printf("}");
}
};
namespace detail {
template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
constexpr index_t NDimHMajor = OuterDstr::NDimX;
using RsLengths =
sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;
constexpr auto hs_lengthss = generate_tuple(
[&](auto i) {
return merge_sequences(typename OuterDstr::HsLengthss{}[i],
typename InnerDstr::HsLengthss{}[i]);
},
number<NDimHMajor>{});
//
constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
// R dimension
rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
// Hs dimensions
static_for<0, NDimHMajor, 1>{}([&](auto i) {
rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
});
return rhs_major_2_ndim_outer_rhs_minor_;
}();
// Ps2RHssMinor
constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
[&](auto p) {
constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
constexpr auto updated_inner_p_2_rhss_minor = [&]() {
array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_p_2_rhss_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
}
return updated_inner_p_2_rhss_minor_;
}();
return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
},
number<InnerDstr::NDimP>{});
// Ys2RHsMinor
constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
for(index_t i = 0; i < ndim_tmp; i++)
{
index_t rh_major = inner_ys_2_rhs_major[i];
index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
}
return updated_inner_ys_2_rhs_minor__;
}();
return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
}();
//
constexpr auto ps_2_rhss_major =
container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
constexpr auto ps_2_rhss_minor =
container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
//
constexpr auto ys_2_rhs_major =
merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
constexpr auto ys_2_rhs_minor =
merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
return tile_distribution_encoding<RsLengths,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_2_rhss_major)>,
remove_cvref_t<decltype(ps_2_rhss_minor)>,
remove_cvref_t<decltype(ys_2_rhs_major)>,
remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}
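// Illustrative sketch (made-up lengths): embedding an inner per-lane encoding into an
// outer per-warp encoding along the same single X dim.
//   outer: R = <>, H = tuple<seq<2>>,      P0 (warp) -> (rh 1, 0), no Y
//   inner: R = <>, H = tuple<seq<64, 2>>,  P0 (lane) -> (rh 1, 0), Y0 -> (rh 1, 1)
// The embedded result concatenates the H factors per X dim and shifts the inner
// rh-minor indices past the outer ones:
//   result: R = <>, H = tuple<seq<2, 64, 2>>,
//           Ps2RHssMajor = tuple<seq<1>, seq<1>>, Ps2RHssMinor = tuple<seq<0>, seq<1>>,
//           Ys2RHsMajor = seq<1>, Ys2RHsMinor = seq<2>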
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto I1 = number<1>{};
// FIXME: increase these values if they turn out to be too small
constexpr index_t max_ndim_r_out = 20;
constexpr index_t max_ndim_y_out = 20;
//
constexpr index_t ndim_p = InDstr::NDimP;
constexpr index_t ndim_x_in = InDstr::NDimX;
constexpr index_t ndim_y_in = InDstr::NDimY;
constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
// ndims_ps_low
constexpr auto ndims_ps_low = generate_array(
[&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
// is_rh_major_in_for_reduce
array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
{
index_t rh_major = reduce_dim_xs_in[i] + 1;
is_rh_major_in_for_reduce(rh_major) = true;
}
// is_y_in_for_reduce
array<bool, ndim_y_in> is_y_in_for_reduce{false};
for(index_t i = 0; i < ndim_y_in; i++)
{
index_t rh_major = InDstr::ys_to_rhs_major_[i];
if(is_rh_major_in_for_reduce[rh_major])
{
is_y_in_for_reduce(i) = true;
}
}
// is_rh_minor_in_for_y_reduce
array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
static_for<0, ndim_y_in, 1>{}([&](auto i) {
index_t rh_major = InDstr::ys_to_rhs_major_[i];
index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
if(is_y_in_for_reduce[i])
{
is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
}
});
// in2out_rh_major
array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
index_t cnt_ndim_rh_major_out = 0;
for(index_t i = 0; i < ndim_rh_major_in; i++)
{
if(is_rh_major_in_for_reduce[i])
{
in2out_rh_major(i) = 0;
}
else
{
in2out_rh_major(i) = cnt_ndim_rh_major_out;
cnt_ndim_rh_major_out++;
}
}
// rs_lengths_out, in2out_rh_minor
array<index_t, max_ndim_r_out> rs_lengths_out{-1};
array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
// loop over input R dim
for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
{
// rs_lengths_out
rs_lengths_out(i) = InDstr::rs_lengths_[i];
// in2out_rh_minor
in2out_rh_minor(0)(i) = i;
}
// loop over input H Dim
index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
constexpr auto h_major_in = rh_major_in - I1;
constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
if(is_rh_major_in_for_reduce[rh_major_in])
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
{
// rs_lengths_out
rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
cnt_ndim_r_out++;
}
}
}
else
{
for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
{
// in2out_rh_minor
in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
}
}
});
// ndim_r_out
const index_t ndim_r_out = cnt_ndim_r_out;
// ndims_hs_minor_out, hs_lengthss_out
array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
index_t cnt_ndim_x_out = 0;
static_for<0, ndim_x_in, 1>{}([&](auto i) {
if(not is_rh_major_in_for_reduce[i + I1])
{
// ndims_hs_minor_out
ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
// hs_lengthss_out
static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
[&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
cnt_ndim_x_out++;
}
});
// ps_to_rhss_major_out, ps_to_rhss_minor_out
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
static_for<0, ndim_p, 1>{}([&](auto idim_p) {
static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
});
});
// ys_to_rhs_major_out, ys_to_rhs_minor_out
array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
index_t cnt_ndim_y_out = 0;
static_for<0, ndim_y_in, 1>{}([&](auto i) {
if(not is_y_in_for_reduce[i])
{
index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
cnt_ndim_y_out++;
}
});
// ndim_y_out
const index_t ndim_y_out = cnt_ndim_y_out;
//
return make_tuple(ndim_x_out,
ndim_p,
ndim_y_out,
ndim_r_out,
ndims_hs_minor_out,
ndims_ps_low,
rs_lengths_out,
hs_lengthss_out,
ps_to_rhss_major_out,
ps_to_rhss_minor_out,
ys_to_rhs_major_out,
ys_to_rhs_minor_out);
}
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
constexpr index_t ndim_x = impl.template at<0>();
constexpr index_t ndim_p = impl.template at<1>();
constexpr index_t ndim_y = impl.template at<2>();
constexpr index_t ndim_r = impl.template at<3>();
constexpr auto ndims_hs_minor = impl.template at<4>();
constexpr auto ndims_ps_low = impl.template at<5>();
constexpr auto rs_lengths_impl = impl.template at<6>();
constexpr auto hs_lengthss_impl = impl.template at<7>();
constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
constexpr auto ps_to_rhss_major =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
constexpr auto ps_to_rhss_minor =
TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
remove_cvref_t<decltype(hs_lengthss)>,
remove_cvref_t<decltype(ps_to_rhss_major)>,
remove_cvref_t<decltype(ps_to_rhss_minor)>,
remove_cvref_t<decltype(ys_to_rhs_major)>,
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
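// Illustrative sketch (made-up lengths): reducing over X1 of a 2D encoding.
//   input:  R = <>, H = tuple<seq<2, 2>, seq<4>>,
//           P0 -> {(rh 1, 1), (rh 2, 0)}, Y0 -> (rh 1, 0)
//   reduce_dim_xs_in = sequence<1> (drop X1)
// The H sub-dims of the reduced X that were owned by P become replication (R) dims:
//   output: R = <4>, H = tuple<seq<2, 2>>,
//           P0 -> {(rh 1, 1), (rh 0, 0)}, Y0 -> (rh 1, 0)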
} // namespace detail
} // namespace ck_tile