/*************************************************************************************************** * OPUS, AI (O)(P)erator Micro(U) (S)TD * * Crafting the micro standard templates for AI Operators on ROCm * * MIT License * Copyright (C) 2025-2026 carlus.huang@amd.com * **************************************************************************************************/
// Overview of this part of the header, in order of appearance: build macros (OPUS_H, OPUS_D, OPUS_H_D,
// OPUS_D_EXTERN, OPUS_H_D_EXTERN), type traits, number/bool_constant with their operators, the underscore
// placeholder, constexpr functional math, seq and index-sequence helpers, static_for/static_ford, array,
// tuple (plus the std::tuple_size/std::tuple_element hooks used by structured bindings), tuple transforms,
// and the layout family (layout, layout_linear, layout_cached) with its issue-space/offset helpers.
// NOTE(review): in the text under review the #include targets and the contents of the template
// parameter/argument lists are not visible, and line breaks are collapsed - confirm every change below
// against the checked-in file before merging.
 #pragma once // clang-format off #include #include #ifndef OPUS_ENABLE_RUNTIME_QUERY #define OPUS_ENABLE_RUNTIME_QUERY 0 #endif #if OPUS_ENABLE_RUNTIME_QUERY && defined(__HIPCC__) && !defined(__HIP_DEVICE_COMPILE__) #include #endif #ifdef __HIPCC__ #define OPUS_H inline __host__ #define OPUS_D __device__ #define OPUS_H_D inline __host__ __device__ #define OPUS_D_EXTERN __device__ #define OPUS_H_D_EXTERN __host__ __device__ #else #define OPUS_H inline #define OPUS_D #define OPUS_H_D inline #define OPUS_D_EXTERN #define OPUS_H_D_EXTERN #endif #ifndef OPUS_FP32_to_BF16_DEFAULT #define OPUS_FP32_to_BF16_DEFAULT 2 // truncate, valid on gfx94* and before #endif #ifndef OPUS_TILE_CONTAINER #define OPUS_TILE_CONTAINER 0 // 0:vector, 1:array of vector, 2:flattened array #endif namespace opus { ///////////////////////////////////////////////////////////////////////////////////////////////////////// // type traits using std::remove_reference; using std::remove_reference_t; using std::remove_cv; using std::remove_cv_t; using std::is_same; using std::is_same_v; template struct remove_cvref { using type = remove_cv_t>; }; template using remove_cvref_t = remove_cv_t>; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // constant using index_t = int; using long_index_t = long long; template struct number : public std::integral_constant {}; template struct bool_constant : public std::bool_constant {}; typedef bool_constant true_type; typedef bool_constant false_type; template struct is_constant : public false_type {}; template struct is_constant> : true_type {}; 
template struct is_constant> : true_type {}; template struct is_constant> : true_type {}; template static constexpr bool is_constant_v = is_constant>::value; // prefer use this // using opus::operator""_I; // => add this in your code to utilize the literal cast, e.g. 2_I, 3_I template OPUS_H_D constexpr decltype(auto) operator""_I() { constexpr auto to_number_ = []() { index_t v = 0; ((v = v * 10 + (Ds - '0')), ...); return v; }; return number{}; } #define OPUS_LEFT_UNARY_OP(OP) template OPUS_H_D constexpr auto operator OP(number) { return number<(OP x)>{}; } #define OPUS_BINARY_OP(OP) template OPUS_H_D constexpr auto operator OP(number, number) { return number<(x OP y)>{}; } OPUS_LEFT_UNARY_OP(+) OPUS_LEFT_UNARY_OP(-) OPUS_LEFT_UNARY_OP(~) OPUS_LEFT_UNARY_OP(!) OPUS_BINARY_OP(+) OPUS_BINARY_OP(-) OPUS_BINARY_OP(*) OPUS_BINARY_OP(/) OPUS_BINARY_OP(%) OPUS_BINARY_OP(&) OPUS_BINARY_OP(|) OPUS_BINARY_OP(^) OPUS_BINARY_OP(<<) OPUS_BINARY_OP(>>) OPUS_BINARY_OP(&&) OPUS_BINARY_OP(||) OPUS_BINARY_OP(==) OPUS_BINARY_OP(!=) OPUS_BINARY_OP(>) OPUS_BINARY_OP(<) OPUS_BINARY_OP(>=) OPUS_BINARY_OP(<=) #undef OPUS_LEFT_UNARY_OP #undef OPUS_BINARY_OP template constexpr bool is_any_of() noexcept { return (std::is_same_v || ...); } template static constexpr bool is_any_of_v = is_any_of(); ///////////////////////////////////////////////////////////////////////////////////////////////////////// // underscore, useful structure to mock struct underscore { /*who am I*/ }; static constexpr underscore _; template struct is_underscore : false_type {}; template <> struct is_underscore : true_type {}; template static constexpr bool is_underscore_v = is_underscore::value; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // constexpr functional math struct plus { template OPUS_H_D constexpr decltype(auto) operator()(X a, Y b) const { return a + b; } }; struct minus { template OPUS_H_D constexpr decltype(auto) operator()(X a, Y b) const { 
return a - b; } }; struct multiplies { template OPUS_H_D constexpr decltype(auto) operator()(X a, Y b) const { return a * b; } }; struct divides { template OPUS_H_D constexpr decltype(auto) operator()(X a, Y b) const { return a / b; } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // seq template class seq { public: using value_type = index_t; OPUS_H_D static constexpr index_t size() { return sizeof...(Is);} OPUS_H_D constexpr value_type operator[](index_t i) const { return data[i]; } OPUS_H_D static constexpr value_type at(index_t i) { return data[i]; } template OPUS_H_D static constexpr value_type at() { return data[I]; } template OPUS_H_D static constexpr value_type at(number) { return data[I]; } private: static constexpr value_type data[sizeof...(Is) + 1] = {Is..., value_type{}}; }; template OPUS_H_D constexpr auto seq_pop_front(seq) { return seq{}; } template OPUS_H_D constexpr decltype(auto) get(seqconst& ) { static_assert(I < sizeof...(Is)); return seq::at(number{}); } template OPUS_H_D constexpr decltype(auto) get(seq& ) { static_assert(I < sizeof...(Is)); return seq::at(number{}); } template OPUS_H_D constexpr decltype(auto) get(seq&& ) { static_assert(I < sizeof...(Is)); return seq::at(number{}); } namespace impl { template struct __integer_sequence; template struct __integer_sequence { using seq_type = seq; }; template struct __steped_integer_seq; template struct __steped_integer_seq> { using seq_type = seq<(Start + Is * Step) ... 
>; }; template struct __make_index_seq; template struct __make_index_seq> { using seq_type = typename __make_integer_seq<__integer_sequence, index_t, N>::seq_type; }; template struct __make_index_seq> { using seq_type = typename __steped_integer_seq >::seq_type>::seq_type; }; template struct __make_index_seq> { using seq_type = typename __steped_integer_seq >::seq_type>::seq_type; }; } // namespace impl // make_index_seq<5> -> seq<0,1,2,3,4> | make_index_seq<4, 9> -> seq<4,5,6,7,8> | make_index_seq<4, 8, 2> -> seq<4, 6> template using make_index_seq = typename impl::__make_index_seq>::seq_type; namespace impl { template struct __make_repeated_seq { template static constexpr auto __make(seq) { return seq<(void(I), Value)...>{}; } using seq_type = decltype(__make(make_index_seq{})); }; } // namespace impl template using make_repeated_seq = typename impl::__make_repeated_seq::seq_type; template OPUS_H_D constexpr auto concat_seq(seq, seq) { return seq{}; } namespace impl { template struct reduce_seq_impl; template struct reduce_seq_impl> { using type = typename reduce_seq_impl>::type; }; template struct reduce_seq_impl> { using type = seq; }; template struct reduce_seq_impl> { using type = seq<>; }; } template OPUS_H_D constexpr auto reduce_seq(seq) { return typename impl::reduce_seq_impl>::type{}; } template OPUS_H_D constexpr auto reduce_seq_sum(seq) { if constexpr (sizeof...(Xs) == 0) return seq<>{}; else return seq<(Xs + ...)>{}; } template OPUS_H_D constexpr auto reduce_seq_mul(seq) { if constexpr (sizeof...(Xs) == 0) return seq<>{}; else return seq<(Xs * ...)>{}; } template struct is_seq : false_type {}; template struct is_seq> : true_type {}; template constexpr bool is_seq_v = is_seq>::value; template OPUS_H_D constexpr std::enable_if_t, index_t> size(T&&) { return remove_cvref_t::size(); /* tuple size */} template OPUS_H_D constexpr std::enable_if_t, index_t> size() { return remove_cvref_t::size(); /* tuple size */} template , bool> = true> OPUS_H_D constexpr 
decltype(auto) get(T const& t) { static_assert(I < T::size()); return t[number{}]; } template , bool> = true> OPUS_H_D constexpr decltype(auto) get(T& t) { static_assert(I < T::size()); return t[number{}]; } template , bool> = true> OPUS_H_D constexpr decltype(auto) get(T&& t) { static_assert(I < T::size()); return t[number{}]; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // functional namespace impl { template struct static_for_impl; template struct static_for_impl> { template OPUS_H_D constexpr void operator()(F&& f) const { (f(number{}), ...); } }; } // namespace impl template OPUS_H_D constexpr void static_for(F f) { impl::static_for_impl>{}(f); } template && ...), bool> = true> OPUS_H_D constexpr void static_for(F f, R...) { impl::static_for_impl>{}(f); } namespace impl { // Flat static_ford: single-level static_for, non-recursive compile-time index decomposition via fold expressions template constexpr index_t ford_stride(seq, seq) { return ((Is > D ? Ns : index_t(1)) * ... * index_t(1)); } template constexpr index_t ford_dim(seq, seq) { return ((Is == D ? Ns : index_t(1)) * ...); } template constexpr index_t ford_at() { return (I / ford_stride(make_index_seq{}, seq{})) % ford_dim(make_index_seq{}, seq{}); } template struct static_ford_impl; template struct static_ford_impl> { template OPUS_H_D static constexpr void call_one(F& f, number, seq) { f(number()>{}...); } template OPUS_H_D constexpr void operator()(F f) const { static_for<(Ns * ... 
* 1)>([&](auto I) { call_one(f, I, make_index_seq{}); }); } }; template <> struct static_ford_impl> { template OPUS_H_D constexpr void operator()(F f) const { f(); } }; } template OPUS_H_D constexpr void static_ford(F f) { impl::static_ford_impl>{}(f); } template OPUS_H_D constexpr void static_ford(seq, F f) { impl::static_ford_impl>{}(f); } template struct tuple; template OPUS_H_D constexpr void static_ford(tuple...>, F f) { impl::static_ford_impl>{}(f); } template struct get_value_type { using type = remove_cvref_t; }; template using get_value_t = typename get_value_type::type; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // array, enhanced C like array style template struct array { using value_type = remove_cvref_t; using type = array; #if 0 // don't define following, just let me be trivially copyable class OPUS_H_D constexpr array() = default; OPUS_H_D constexpr array(const type& o) { static_for([&](auto i){ content[i.value] = o[i.value]; }); } OPUS_H_D constexpr type& operator=(const type o) { static_for([&](auto i){ content[i.value] = o[i.value]; }); return *this; } template, value_type> && ...), bool> = true> OPUS_H_D constexpr array(Z&&... 
zs) : content{zs...} { /* used for make_array */ } #endif OPUS_H_D constexpr value_type& operator[](index_t pos) { return content[pos]; } OPUS_H_D constexpr const value_type& operator[](index_t pos) const { return content[pos]; } template OPUS_H_D constexpr value_type& operator[](number) { return content[I]; } template OPUS_H_D constexpr const value_type& operator[](number) const { return content[I]; } OPUS_H_D constexpr void fill(const T& value) { static_for([&](auto i){ content[i.value] = value; }); } OPUS_H_D constexpr void clear() { fill(static_cast(0)); } OPUS_H_D static constexpr bool empty() { return size() == 0; } OPUS_H_D static constexpr index_t size() { return N; } // we need this "content" member to have a default value, so that the implicitly defined constructor could be constexpr // see: https://en.cppreference.com/w/cpp/language/constexpr.html#constexpr_constructor value_type content[N] {}; }; template OPUS_H_D constexpr bool operator==(const array& x, const array& y) { for (index_t i = 0; i < N; ++i) { if (x[i] != y[i]) { return false; } } return true; } template OPUS_H_D constexpr void clear(array& a) { a.clear(); } template OPUS_H_D constexpr void fill(array& a, T const& value) { a.fill(value); } template struct is_array : false_type {}; template struct is_array> : true_type {}; template constexpr bool is_array_v = is_array>::value; template struct get_value_type>> { using type = typename T::value_type; }; namespace impl { template struct is_ref_wrapper : std::false_type{}; template struct is_ref_wrapper> : std::true_type{}; template using not_ref_wrapper = std::negation>>; template struct array_return_type_helper { using type = D; }; template struct array_return_type_helper : std::common_type { static_assert(std::conjunction_v...>, "Types cannot contain reference_wrappers when D is void"); }; template using array_return_type = opus::array::type, sizeof...(Types)>; } template OPUS_H_D constexpr impl::array_return_type make_array(Types&&... 
t) { return {std::forward(t)...}; } template , bool> = true> OPUS_H_D constexpr decltype(auto) get(T const& t) { static_assert(I < T::size()); return t[number{}]; } template , bool> = true> OPUS_H_D constexpr decltype(auto) get(T& t) { static_assert(I < T::size()); return t[number{}]; } template , bool> = true> OPUS_H_D constexpr decltype(auto) get(T&& t) { static_assert(I < T::size()); return t[number{}]; } namespace impl { template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, seq, seq) { return opus::make_array(get(t0)..., get(t1)...); } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, T2 const& t2, seq, seq, seq) { return opus::make_array(get(t0)..., get(t1)..., get(t2)...); } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, seq, seq, seq, seq) { return opus::make_array(get(t0)..., get(t1)..., get(t2)..., get(t3)...); } } template OPUS_H_D constexpr auto concat_array(T0 const& t0) { return t0; } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1) { return impl::concat_array(t0, t1, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, T2 const& t2) { return impl::concat_array(t0, t1, t2, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) { return impl::concat_array(t0, t1, t2, t3, make_index_seq{}, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_array(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, Ts const&... 
ts) { return concat_array(concat_array(t0, t1, t2, t3), concat_array(t4, ts...)); } template OPUS_H_D constexpr std::enable_if_t, index_t> size(T&&) { return remove_cvref_t::size(); /* tuple size */} template OPUS_H_D constexpr std::enable_if_t, index_t> size() { return remove_cvref_t::size(); /* tuple size */} ///////////////////////////////////////////////////////////////////////////////////////////////////////// // tuple namespace impl { template || std::is_void_v)> struct tuple_object {}; // the place where content is stored template struct tuple_object { OPUS_H_D constexpr tuple_object() {} template , tuple_object>::value, bool>::type = false> OPUS_H_D constexpr tuple_object(U&&) {} }; template struct tuple_object { OPUS_H_D constexpr tuple_object() : element{} {} template , tuple_object>::value, bool>::type = false> OPUS_H_D constexpr tuple_object(U&& e) : element(std::forward(e)) {} T element; }; // NOTE: we return an instance (not a reference) if content is empty template OPUS_H_D constexpr T getv(const tuple_object&) { return {}; } template OPUS_H_D constexpr const T& getv(const tuple_object& x) { return x.element; } template OPUS_H_D constexpr T& getv(tuple_object& x) { return x.element; } template OPUS_H_D constexpr T&& getv(tuple_object&& x) { return static_cast(x.element); } template struct tuple_base; template struct tuple_base, T...> : tuple_object... { OPUS_H_D constexpr tuple_base() = default; template , tuple_base>::value, bool>::type = false> OPUS_H_D constexpr tuple_base(U&& u) : tuple_object(std::forward(u))... {} template = 2, bool>::type = false> OPUS_H_D constexpr tuple_base(U&&... u) : tuple_object(std::forward(u))... 
{ static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U), "wrong!"); } }; } // namespace impl template struct tuple : impl::tuple_base, T...> { OPUS_H_D static constexpr index_t size() { return sizeof...(T); } using base = impl::tuple_base, T...>; OPUS_H_D constexpr tuple() = default; template , tuple>::value, bool>::type = false> OPUS_H_D constexpr tuple(U&& u) : base(std::forward(u)) {} template = 2, bool>::type = false> OPUS_H_D constexpr tuple(U&&... u) : base(std::forward(u)...) {} }; /* deduction guide; OPUS_H_D_EXTERN is __host__ __device__ under HIPCC and empty otherwise, so the guide also parses in host-only builds */ template OPUS_H_D_EXTERN tuple(T&&...) -> tuple...>; namespace impl { template struct tuple_array_helper; template struct tuple_array_helper> { using type = tuple{}))...>; }; } template using tuple_array = typename impl::tuple_array_helper>::type; // alias for tuple, Nx Ts // get the I-th type within the tuple, O(1) via compiler intrinsic template struct tuple_element; template struct tuple_element> { using type = __type_pack_element; }; template using tuple_element_t = typename tuple_element::type; template OPUS_H_D constexpr decltype(auto) get(tuple const& t) { static_assert(I < sizeof...(T)); return impl::getv(t); } template OPUS_H_D constexpr decltype(auto) get(tuple& t) { static_assert(I < sizeof...(T)); return impl::getv(t); } template OPUS_H_D constexpr decltype(auto) get(tuple&& t) { static_assert(I < sizeof...(T)); return impl::getv(std::move(t)); } template /*recursive get*/ /* NOTE(review): std::move is applied to t unconditionally; if T&& is a deduced forwarding reference, lvalue arguments are cast to rvalues here - confirm this is intended (vs std::forward) */ OPUS_H_D constexpr decltype(auto) get(T&& t) { return get(get(std::move(t))); } template OPUS_H_D constexpr auto make_tuple(T&&... xs) { return tuple...>(std::forward(xs)...); } template && ...), bool> = true> // const integer based static_for loop OPUS_H_D constexpr void static_for(F f, R... 
range) { if constexpr (sizeof...(range) == 1) { auto end = get<0>(make_tuple(range...)); for(index_t i = 0; i < end; i++) { f(i); } } else if constexpr (sizeof...(range) == 2) { auto [start, end] = make_tuple(range...); for(index_t i = start; i < end; i++) { f(i); } } else if constexpr (sizeof...(range) == 3) { auto [start, end, step] = make_tuple(range...); for(index_t i = start; i < end; i += step) { f(i); } } } namespace impl { template OPUS_H_D constexpr auto make_repeated_tuple(T&& x, seq) { return opus::make_tuple((void(Is), std::forward(x))...); } } // namespace impl template OPUS_H_D constexpr auto make_repeated_tuple(T&& x) { return impl::make_repeated_tuple(std::forward(x), make_index_seq{}); } template OPUS_H_D constexpr auto make_repeated_tuple(T&& x, number) { return impl::make_repeated_tuple(std::forward(x), make_index_seq{}); } namespace impl { template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, seq, seq) { return opus::make_tuple(get(t0)..., get(t1)...); } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, seq, seq, seq) { return opus::make_tuple(get(t0)..., get(t1)..., get(t2)...); } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, seq, seq, seq, seq) { return opus::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); } } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0) { return t0; } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1) { return impl::concat_tuple(t0, t1, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2) { return impl::concat_tuple(t0, t1, t2, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) { return impl::concat_tuple(t0, t1, t2, t3, make_index_seq{}, make_index_seq{}, make_index_seq{}, 
make_index_seq{}); } namespace impl { template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, seq, seq, seq, seq, seq) { return opus::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); } } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) { return impl::concat_tuple(t0, t1, t2, t3, t4, make_index_seq{}, make_index_seq{}, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_tuple(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) { return concat_tuple(concat_tuple(t0, t1, t2, t3, t4), concat_tuple(t5, ts...)); } template struct is_tuple : false_type {}; template struct is_tuple> : true_type {}; template static constexpr bool is_tuple_v = is_tuple>::value; template struct is_static_tuple : is_constant> {}; template <> struct is_static_tuple : true_type {}; template struct is_static_tuple> : bool_constant<(is_static_tuple::value && ...)> {}; template static constexpr bool is_static_tuple_v = is_static_tuple>::value; template struct get_value_type>> { using type = tuple_element_t<0, T>; }; // TODO: get the first element type template OPUS_H_D constexpr std::enable_if_t, index_t> size(T&&) { return remove_cvref_t::size(); /* tuple size */} template OPUS_H_D constexpr std::enable_if_t, index_t> size() { return remove_cvref_t::size(); /* tuple size */} template , bool> = true> OPUS_H_D constexpr auto explode_tuple(const T& t) { return opus::make_tuple(t); } template OPUS_H_D constexpr auto explode_tuple(const T&, seq); template , bool> = true> OPUS_H_D constexpr auto explode_tuple(const T& t) { return explode_tuple(t, make_index_seq()>{}); } template OPUS_H_D constexpr auto explode_tuple(const T& t, seq) { return concat_tuple(explode_tuple(get(t))...); } template OPUS_H_D constexpr auto flatten_tuple_general(const T& t, seq) { return 
concat_tuple(explode_tuple(get(t))...); } template && !(is_tuple_v>>), bool> = true> OPUS_H_D constexpr auto flatten_tuple(const T& t) { return t; } // already flat template , bool> = true> OPUS_H_D constexpr auto flatten_tuple(const T& t) { return flatten_tuple_general(t, make_index_seq()>{}); } // non-tuple (e.g. seq) namespace impl { // direct flatten for 1-level nested tuples — bypasses concat_tuple + explode_tuple template constexpr auto group_sizes(seq) { return seq>()...>{}; } template constexpr index_t group_total(seq) { return (size>() + ...); } template constexpr index_t flat_group(seq, seq) { index_t acc = 0, r = 0; ((void)(acc += Ns, (acc <= J ? (void)(r = Gs + 1) : (void)0)), ...); return r; } template constexpr index_t group_offset(seq) { return ((Gs < G ? size>() : 0) + ...); } template OPUS_H_D constexpr auto flatten_at(const T& t) { constexpr auto gs = make_index_seq()>{}; constexpr index_t G = flat_group(gs, GS{}); return get(gs)>(get(t)); } template OPUS_H_D constexpr auto flatten_tuple_impl(const T& t, seq) { return opus::make_tuple(flatten_at(t)...); } } template && (is_tuple_v>>), bool> = true> OPUS_H_D constexpr auto flatten_tuple(const T& t) { using U = remove_cvref_t; constexpr auto gs = make_index_seq()>{}; return impl::flatten_tuple_impl(gs))>(t, make_index_seq(gs)>{}); } namespace impl { template OPUS_H_D constexpr auto embed_nested_tuple_impl(const Outer& ot, const Inner& it, seq) { return opus::make_tuple(concat_tuple(get(ot), get(it))...); } template OPUS_H_D constexpr auto tuple_count_impl(seq) { return (number(T{}))>, remove_cvref_t> ? 
1 : 0>{} + ...); } } // Outer: tuple, tuple>, Inner: tuple, tuple> => tuple, tuple> template OPUS_H_D constexpr auto embed_nested_tuple(const Outer& ot, const Inner& it) { static_assert(size() == size()); return impl::embed_nested_tuple_impl(ot, it, make_index_seq()>{}); } template< typename TargetType, typename T, std::enable_if_t, bool> = true> OPUS_H_D constexpr index_t tuple_count(const T& /*t*/) { return impl::tuple_count_impl>(make_index_seq()>{}).value; } template< typename TargetType, typename T, std::enable_if_t, bool> = true> OPUS_H_D constexpr index_t tuple_count() { return impl::tuple_count_impl>(make_index_seq()>{}).value; } template OPUS_H_D constexpr auto seq_to_tuple(seq) { return opus::make_tuple(number{}...); } template OPUS_H_D constexpr auto to_tuple(seq) { return opus::make_tuple(number{}...); } template, bool> = true> OPUS_H_D constexpr auto to_tuple(const T& t) { return t; } namespace impl { template OPUS_H_D constexpr auto reduce_tuple_impl(const T& t, seq<>) { return t; } template OPUS_H_D constexpr auto reduce_tuple_impl(const T& t, seq) { return t; } template OPUS_H_D constexpr auto reduce_tuple_impl(const T& t, seq) { return reduce_tuple_impl(opus::make_tuple(R{}(get(t), get(t)), get(t)...), make_index_seq{}); } } template, bool> = true> OPUS_H_D constexpr auto reduce_tuple(const T & t) { return impl::reduce_tuple_impl(t, make_index_seq()>{}); } template, bool> = true> OPUS_H_D constexpr auto reduce_tuple_sum(const T & t) { return reduce_tuple(t); } template, bool> = true> OPUS_H_D constexpr auto reduce_tuple_mul(const T & t) { return reduce_tuple(t); } // Fast path: fold expression for tuple of number<> types (avoids N-1 intermediate tuple types) template && ...), bool> = true> OPUS_H_D constexpr auto reduce_tuple_mul(const tuple&) { return opus::tuple>{}; } namespace impl { template OPUS_H_D constexpr index_t underscore_count_in(seq) { return ((is_underscore_v(PT{}))>> ? 1 : 0) + ... 
+ 0); } template OPUS_H_D constexpr index_t peephole_idx() { constexpr index_t c = underscore_count_in(make_index_seq{}); return c < MaxN::value ? c : MaxN::value - 1; } template OPUS_H_D constexpr auto to_peepholed_seq_impl(seq) { return seq()...>{}; } template OPUS_H_D constexpr decltype(auto) merge_peepholed_tuple_impl(PeepholedTuple&& pt, IncomTuple&& it, seq, seq) { return opus::make_tuple([&](){ if constexpr (is_underscore_v(pt))>>) return get(it); else return get(pt);}()... ); } } // (Peepholed)tuple<*, *, _, *, _> + (Income)tuple<#, @> -> tuple<*, *, #, *, @>. "_"(underscore) indicate a peephole for income tuple to chime in template OPUS_H_D constexpr decltype(auto) merge_peepholed_tuple(PeepholedTuple&& pt, IncomeTuple&& it) { if constexpr (tuple_count() == 0) return pt; else { constexpr auto income_seq = impl::to_peepholed_seq_impl< remove_cvref_t, number()> >(make_index_seq()>{}); return impl::merge_peepholed_tuple_impl(std::forward(pt), std::forward(it), make_index_seq()>{}, income_seq); } } } // namespace opus // implementing the "tuple-like binding protocol", don't use below directly namespace std { template struct tuple_size> : std::integral_constant {}; template struct tuple_size> : std::integral_constant {}; template struct tuple_element> { using type = __type_pack_element; }; template struct tuple_element> { using type = const __type_pack_element; }; } // namespace std namespace opus { ///////////////////////////////////////////////////////////////////////////////////////////////////////// // transforms template constexpr auto embed(const X& x, const Y& y, seq) { return ( ... 
+ (get(x) * get(y))); } template constexpr auto embed(const X& x, const Y& y) { return embed(x, y, make_index_seq{}); } namespace impl { template OPUS_H_D constexpr auto transform_tuple_impl(F f, const X& x, seq) { return opus::make_tuple(f(get(x))...); } template OPUS_H_D constexpr auto transform_tuple_with_idx_impl(F f, const X& x, seq) { return opus::make_tuple(f(get(x), number{})...); } } // namespace impl // f(auto item) template OPUS_H_D constexpr auto transform_tuple(F f, const X& x) { return impl::transform_tuple_impl(f, x, make_index_seq()>{}); } // f(auto item, auto index) template OPUS_H_D constexpr auto transform_tuple_with_idx(F f, const X& x) { return impl::transform_tuple_with_idx_impl(f, x, make_index_seq()>{}); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // layout, simple linear nd layout with stride, static or dynamic supported namespace impl { template OPUS_H_D constexpr auto packed_stride_at(seq) { return (get(Shape{}) * ... * number<1>{}); } template OPUS_H_D constexpr auto packed_shape_to_stride_impl(seq) { return opus::make_tuple(packed_stride_at(make_index_seq{})...); } } template OPUS_H_D constexpr auto packed_shape_to_stride(const Shape&) { return impl::packed_shape_to_stride_impl(make_index_seq{}); } template OPUS_H_D constexpr decltype(auto) coord_to_linear(const Layout& layout, const Coord& coord) { static_assert(size() == size()); return embed(layout.stride(), coord); } // Shape/Stride/Coord, they are all tuples. 
if Coord is not false_type, will use merge_peepholed_tuple() to construct real coord template struct layout : public tuple, remove_cvref_t, remove_cvref_t> { using base = tuple, remove_cvref_t, remove_cvref_t>; using Shape = remove_cvref_t; using Stride = remove_cvref_t; using Coord = remove_cvref_t; // peepholed coord static constexpr index_t rank = Shape::size(); static_assert(Shape::size() == Stride::size()); static_assert(std::is_same_v || size, Shape, Coord>>() == rank, "Coord should be either false_type or a tuple with same size as Shape"); static constexpr index_t coord_rank = [](){ if constexpr (std::is_same_v) return rank; else return rank - tuple_count(Coord{}); }(); OPUS_H_D constexpr layout(const Shape& shape, const Stride& stride, const Coord& coord = {}) : base(shape, stride, coord){} OPUS_H_D constexpr layout(Shape&& shape, Stride&& stride, Coord&& coord = {}) : base(shape, stride, coord){} // get ith element from shape/stride. if no I, then get the shape/stride as tuple template OPUS_H_D constexpr decltype(auto) shape() { return get<0,I...>(static_cast(*this)); } template OPUS_H_D constexpr decltype(auto) shape() const { return get<0,I...>(static_cast(*this)); } template OPUS_H_D constexpr decltype(auto) stride() { return get<1,I...>(static_cast(*this)); } template OPUS_H_D constexpr decltype(auto) stride() const { return get<1,I...>(static_cast(*this)); } template OPUS_H_D constexpr decltype(auto) coord() { return get<2,I...>(static_cast(*this)); } template OPUS_H_D constexpr decltype(auto) coord() const { return get<2,I...>(static_cast(*this)); } template && ...), bool> = true> OPUS_H_D constexpr decltype(auto) operator()(Cs&&... 
cs) const { return this->operator()(opus::make_tuple(std::forward(cs)...)); } template , bool> = true> OPUS_H_D constexpr decltype(auto) operator()(InCoord&& c) const { if constexpr (std::is_same_v) return coord_to_linear(*this, c); else return coord_to_linear(*this, merge_peepholed_tuple(coord(), c)); } }; template struct layout_linear; template struct layout_cached; // use cached_vec to dispatch which layout implementation. cached_vec < 0 : "layout", cached_vec == 0 : "layout_linear", cached_vec > 0 : "layout_cached" template OPUS_H_D constexpr auto make_layout(Sx&& s, Sy&& t) { if constexpr (cached_vec < 0) return layout(std::forward(s), std::forward(t)); else if constexpr (cached_vec == 0) return layout_linear>(std::forward(s), std::forward(t)); else return layout_cached>(std::forward(s), std::forward(t)); } template OPUS_H_D constexpr auto make_layout(Sx&& s, Sy&& t, Sz&& c) { /* the branches must form a single if / else-if / else chain so that exactly one return statement (one return type) is instantiated for a given cached_vec */ if constexpr (cached_vec < 0) return layout(std::forward(s), std::forward(t), std::forward(c)); else if constexpr (cached_vec == 0) return layout_linear>(std::forward(s), std::forward(t), std::forward(c)); else return layout_cached>(std::forward(s), std::forward(t), std::forward(c)); } template && ...), bool> = true> OPUS_H_D constexpr auto make_layout(Ts&&... 
ss) { return make_layout(opus::make_tuple(ss...), packed_shape_to_stride(opus::make_tuple(ss...))); } template OPUS_H_D constexpr auto make_layout(S&& s) { return make_layout(std::forward(s), packed_shape_to_stride(s)); } template OPUS_H_D constexpr auto make_layout_packed(S&& s) { return make_layout(std::forward(s), packed_shape_to_stride(s)); } // same as single arg make_layout template OPUS_H_D constexpr auto make_layout_packed(Sx&& s, Sz&& c) { return make_layout(std::forward(s), packed_shape_to_stride(s), std::forward(c)); } template struct layout_linear : public remove_cvref_t{ using base = remove_cvref_t; template OPUS_H_D constexpr layout_linear(const Shape& shape, const Stride& stride, const Coord& coord = {}) : base(shape, stride, coord), linear_offset(0){} template OPUS_H_D constexpr layout_linear(Shape&& shape, Stride&& stride, Coord&& coord = {}) : base(shape, stride, coord), linear_offset(0){} template && ...), bool> = true> OPUS_H_D constexpr decltype(auto) operator()(Cs&&... 
cs) const { return this->operator()(opus::make_tuple(std::forward(cs)...)); } template , bool> = true> OPUS_H_D constexpr decltype(auto) operator()(InCoord&& c) const { if constexpr (std::is_same_v) return linear_offset + coord_to_linear(*this, c); else return linear_offset + coord_to_linear(*this, merge_peepholed_tuple(base::coord(), c)); } OPUS_H_D constexpr void inc(index_t offset) { linear_offset += offset; } OPUS_H_D constexpr layout_linear& operator+=(index_t offset) { inc(offset); return *this; } OPUS_H_D constexpr layout_linear operator+(index_t offset) const { layout_linear result(*this); result += offset; return result; } index_t linear_offset; }; template OPUS_H_D constexpr auto layout_to_vectorized_issue_space(); template OPUS_H_D constexpr auto layout_to_offsets(const Layout& u); template struct layout_cached : public remove_cvref_t { using base = remove_cvref_t; static constexpr index_t cached_vec = cached_vec_; static constexpr auto issue_space_vec = layout_to_vectorized_issue_space(); static constexpr index_t num_issues = get<0>(reduce_tuple_mul(issue_space_vec)).value; template OPUS_H_D constexpr layout_cached(const Shape& shape, const Stride& stride, const Coord& coord = {}) : base(shape, stride, coord), offsets{layout_to_offsets(static_cast(*this))}{} template OPUS_H_D constexpr layout_cached(Shape&& shape, Stride&& stride, Coord&& coord = {}) : base(shape, stride, coord), offsets{layout_to_offsets(static_cast(*this))}{} template && ...), bool> = true> OPUS_H_D constexpr decltype(auto) operator()(Cs&&... 
cs) const { return this->operator()(opus::make_tuple(std::forward(cs)...)); } template , bool> = true> OPUS_H_D constexpr decltype(auto) operator()(InCoord&& c) const { constexpr auto u_linear = make_layout<-1>(issue_space_vec); return offsets[u_linear(c)]; } OPUS_H_D constexpr void inc(index_t offset) { static_for([&](auto i){ offsets[i] += offset; }); } OPUS_H_D constexpr layout_cached& operator+=(index_t offset) { inc(offset); return *this; } OPUS_H_D constexpr layout_cached operator+(index_t offset) const { layout_cached result(*this); result += offset; return result; } array offsets; }; template struct is_layout : false_type {}; template struct is_layout> : true_type {}; template struct is_layout> : true_type {}; template struct is_layout> : true_type {}; template constexpr bool is_layout_v = is_layout>::value; template OPUS_H_D constexpr auto layout_to_issue_space() { using maybe_coord = std::conditional_t, typename Layout::Shape, typename Layout::Coord>; using issue_space_y = remove_cvref_t; using single_issue_space = remove_cvref_t{}, number()>{}))>; using fallback_issue_space_y = std::conditional_t>, single_issue_space, issue_space_y>; using issue_space = std::conditional_t, single_issue_space, fallback_issue_space_y>; return issue_space{}; } template OPUS_H_D constexpr auto vectorize_issue_space(issue_space, number = {}) { constexpr index_t vec_from_issue_space = get() - 1>(issue_space{}).value; // here we get the original last dim length(which should be y dim) static_assert(vec_from_issue_space % vec == 0, "please make sure requested vec size can be dividable of vec from issue space"); constexpr auto issue_space_vec = transform_tuple_with_idx([&](auto item, auto index){ // modify the last dim, divide it by vec. 
Result is still a tuple if constexpr (index.value == size() - 1) return number{}; else return item; }, issue_space{}); return issue_space_vec; } template OPUS_H_D constexpr auto layout_to_vectorized_issue_space() { constexpr auto issue_space = layout_to_issue_space(); constexpr auto issue_space_vec = vectorize_issue_space(issue_space, number{}); return issue_space_vec; } // Cache issue-space computations for load/store (avoids redundant evaluation across methods) template struct layout_load_traits { static constexpr auto issue_space = layout_to_issue_space(); static constexpr auto issue_space_vec = vectorize_issue_space(issue_space, number{}); static constexpr auto r_elem = get<0>(reduce_tuple_mul(issue_space_vec)); }; template struct layout_imm_offsets {}; // cached offsets for tr_load immediate-offset path template struct layout_imm_offsets { using L = remove_cvref_t; static constexpr auto u_linear = make_layout<-1>(layout_load_traits::issue_space_vec); static constexpr auto offsets = layout_to_offsets(L(typename L::Shape{}, typename L::Stride{}, typename L::Coord{})); }; // Runtime flat index → multi-index tuple (all index_t) — avoids per-iteration template instantiation template OPUS_H_D constexpr auto flat_to_coords(index_t flat, seq, tuple...>) { constexpr index_t strides[] = {impl::ford_stride(make_index_seq{}, seq{})...}, dims[] = {Ns...}; return opus::make_tuple(static_cast((flat / strides[Is]) % dims[Is])...); } // Pre-compute offsets via runtime loop — 1 coord_to_linear instantiation per layout instead of N template OPUS_H_D constexpr auto layout_to_offsets(const Layout& u) { using LT = layout_load_traits; constexpr auto issue_space_vec = LT::issue_space_vec; constexpr index_t num_issues = LT::r_elem.value, ndim = size>(); array offsets; for (index_t i = 0; i < num_issues; i++) offsets[i] = u(flat_to_coords(i, make_index_seq{}, issue_space_vec)); return offsets; } 
///////////////////////////////////////////////////////////////////////////////////////////////////////// // vector, a wrapper for __attribute__((ext_vector_type(*))) template // V_ must be literal type, otherwise clang ext_vector_type will not recognize it struct vector { static constexpr index_t N = N_; using value_type = remove_cvref_t; using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous }; template using vector_t = typename vector::type; template struct is_vector : false_type {}; template struct is_vector : true_type {}; template struct is_vector : true_type {}; template struct is_vector : true_type {}; template struct is_vector : true_type {}; template static constexpr bool is_vector_v = is_vector::value; namespace impl { template struct vector_traits_impl { using dtype = remove_cvref_t; static constexpr index_t size() { return 1; } }; template struct vector_traits_impl { using dtype = T; static constexpr index_t size() { return N; } }; template struct vector_traits_impl> { using dtype = T; static constexpr index_t size() { return N; } }; template struct vector_traits_impl> { using dtype = __type_pack_element<0, T...> /*TODO: use first type*/; static constexpr index_t size() { return sizeof...(T); } }; } template struct vector_traits : public impl::vector_traits_impl> {}; template OPUS_H_D constexpr std::enable_if_t, index_t> size(T&&) { return vector_traits::size(); /* vector size */} template OPUS_H_D constexpr std::enable_if_t, index_t> size() { return vector_traits::size(); /* vector size */} template struct get_value_type>> { using type = typename vector_traits::dtype; }; namespace impl { template struct vector_return_type_helper { using type = D; }; template struct vector_return_type_helper : std::common_type { static_assert(std::conjunction_v...>, "Types cannot contain reference_wrappers when D is void"); }; template using vector_return_type = opus::vector_t::type, sizeof...(Types)>; } template constexpr 
impl::vector_return_type make_vector(Types&&... t) { return {std::forward(t)...}; } namespace impl { template OPUS_H_D constexpr auto make_repeated_vector(T&& x, seq) { return opus::make_vector((void(Is), std::forward(x))...); } } // namespace impl template OPUS_H_D constexpr auto make_repeated_vector(T&& x) { return impl::make_repeated_vector(std::forward(x), make_index_seq{}); } template OPUS_H_D constexpr auto make_repeated_vector(T&& x, number) { return impl::make_repeated_vector(std::forward(x), make_index_seq{}); } // vector type can't return reference! error: non-const reference cannot bind to vector element template , bool> = true> OPUS_H_D constexpr typename vector_traits::dtype get(T const& t) { static_assert(I < vector_traits::size()); return t[I]; } template , bool> = true> OPUS_H_D constexpr typename vector_traits::dtype get(T&& t) { static_assert(I < vector_traits::size()); return t[I]; } namespace impl { template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, seq, seq) { if constexpr (std::is_same_v, remove_cvref_t> && sizeof...(I0) > 1) { using R = vector_t>::dtype, sizeof...(I0) + sizeof...(I1)>; return __builtin_bit_cast(R, __builtin_shufflevector(t0, t1, I0..., (sizeof...(I0) + I1)...)); } else { return opus::make_vector(get(t0)..., get(t1)...); } } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, T2 const& t2, seq, seq, seq) { return opus::make_vector(get(t0)..., get(t1)..., get(t2)...); } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, seq, seq, seq, seq) { return opus::make_vector(get(t0)..., get(t1)..., get(t2)..., get(t3)...); } } template OPUS_H_D constexpr auto concat_vector(T0 const& t0) { return t0; } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1) { return impl::concat_vector(t0, t1, make_index_seq()>{}, make_index_seq()>{}); } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, T2 const& 
t2) { return impl::concat_vector(t0, t1, t2, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) { return impl::concat_vector(t0, t1, t2, t3, make_index_seq{}, make_index_seq{}, make_index_seq{}, make_index_seq{}); } template OPUS_H_D constexpr auto concat_vector(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, Ts const&... ts) { return concat_vector(concat_vector(t0, t1, t2, t3), concat_vector(t4, ts...)); } template , bool> = true> OPUS_H_D constexpr void fill(T& a, typename vector_traits::dtype const& value) { if constexpr (size() <= 4) { static_for()>([&](auto i){ a[i.value] = value; }); } else { for (index_t i = 0; i < size(); ++i) a[i] = value; } // runtime loop for large vectors } template , bool> = true> OPUS_H_D constexpr void clear(T& a) { a = {}; } namespace impl { template, bool> = true> OPUS_H_D constexpr auto to_array_impl(const T& t, seq) { return opus::make_array(t[Is]...); } template, bool> = true> OPUS_H_D constexpr auto to_array_impl(const T& t, seq) { return opus::concat_array(to_array_impl(get(t), make_index_seq< size(get(T{})) >{})...); } template && !is_vector_v, bool> = true> OPUS_H_D constexpr vector_t to_vector_impl(const T& t, seq) { return {get(t)...}; } template && is_vector_v, bool> = true> OPUS_H_D constexpr vector_t to_vector_impl(const T& t, seq) { return opus::concat_vector(to_vector_impl(get(t))...); } } template, bool> = true> // vector type to array OPUS_H_D constexpr auto to_array(const T& t) { return impl::to_array_impl(t, make_index_seq()>{}); } template, bool> = true> // array of vector type to array OPUS_H_D constexpr auto to_array(const T& t) { return impl::to_array_impl(t, make_index_seq()>{}); } template, bool> = true> OPUS_H_D constexpr auto to_vector(const T& t) { return impl::to_vector_impl(t, make_index_seq()>{}); } 
///////////////////////////////////////////////////////////////////////////////////////////////////////// // slice namespace impl { template, bool> = true> OPUS_H_D constexpr auto slice_impl(C&& c, seq) { if constexpr (sizeof...(Is) == 1) return opus::make_vector(get(c)...); else { using R = vector_t>::dtype, sizeof...(Is)>; return __builtin_bit_cast(R, __builtin_shufflevector(c, c, Is...)); } } template, bool> = true> OPUS_H_D constexpr auto slice_impl(C&& c, seq) { return opus::make_array(get(c)...); } template, bool> = true> OPUS_H_D constexpr auto slice_impl(C&& c, seq) { return opus::make_tuple(get(c)...); } template, bool> = true> OPUS_H_D constexpr auto slice_impl_i(C&& c, Ts... ss) { vector_t::dtype, len> r; index_t d = 0; static_for([&](auto i){r[d++] = c[i]; }, ss...); return r; } template, bool> = true> OPUS_H_D constexpr auto slice_impl_i(C&& c, Ts... ss) { array r; index_t d = 0; static_for([&](auto i){r[d++] = c[i]; }, ss...); return r; } template OPUS_H_D constexpr bool is_contiguous_seq(seq) { if constexpr (sizeof...(Is) < 2) return true; else { constexpr index_t idx[] = {Is...}; for (index_t i = 1; i < sizeof...(Is); ++i) { if (idx[i] != idx[i - 1] + 1) return false; } return true; } } template || is_array_v || is_tuple_v), bool> = true> OPUS_H_D constexpr auto set_slice_impl(C&& dst_c, V&& src_c, seq, seq) { using dst_t = remove_cvref_t; using src_t = remove_cvref_t; using scalar = typename vector_traits::dtype; constexpr index_t len = sizeof...(Ds); // Copy at dword granularity for sub-dword scalar types with dword-aligned contiguous slices if constexpr ((is_vector_v || is_array_v) && (is_vector_v || is_array_v) && is_contiguous_seq(seq{}) && is_contiguous_seq(seq{}) && sizeof(scalar) < 4 && len > 1) { constexpr index_t epd = 4 / sizeof(scalar); constexpr index_t d0 = seq::at(number<0>{}), s0 = seq::at(number<0>{}), dn = vector_traits::size(), sn = vector_traits::size(); if constexpr (d0 % epd == 0 && s0 % epd == 0 && len % epd == 0 && dn % epd 
== 0 && sn % epd == 0) { auto dst_i32 = __builtin_bit_cast(vector_t, dst_c); const auto src_i32 = __builtin_bit_cast(vector_t, src_c); static_for([&](auto i) { dst_i32[d0 / epd + i.value] = src_i32[s0 / epd + i.value]; }); dst_c = __builtin_bit_cast(dst_t, dst_i32); return; } } if constexpr (is_contiguous_seq(seq{}) && is_contiguous_seq(seq{}) && (is_vector_v || is_array_v) && len > 2) { constexpr index_t d0 = seq::at(number<0>{}), s0 = seq::at(number<0>{}); for (index_t i = 0; i < len; ++i) dst_c[d0 + i] = src_c[s0 + i]; // runtime loop avoids N-element fold instantiation } else { ((dst_c[Ds] = src_c[Ss]), ...); } } } // static/dynamic slice. SS could be either number, or const integer. Note tuple type does not support dynamic slice (ss is integral) // (1).[end] : 0.... end, (2).[start, end] : start...end, (3).[start, end, step], start...end but with step as interval (default is 1) template && (is_constant_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& c, S&&.../*ss*/) { return impl::slice_impl(std::forward(c), make_index_seq<(S::value) ...>{}); } template && (std::is_integral_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& c, S&&...ss) { return impl::slice_impl_i(std::forward(c), ss...); } template && (is_constant_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& c, S&&.../*ss*/) { return impl::slice_impl(std::forward(c), make_index_seq<(S::value) ...>{}); } template && (std::is_integral_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& c, S&&...ss) { return impl::slice_impl_i(std::forward(c), ss...); } template && (is_constant_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& c, S&&.../*ss*/) { return impl::slice_impl(std::forward(c), make_index_seq<(S::value) ...>{}); } template || is_array_v || is_tuple_v) && (is_constant_v && ...), bool> = true> OPUS_H_D constexpr auto set_slice(C&& dst_c, V&& src_c, S&&.../*ss*/) { static_assert(std::is_same_v::dtype, typename vector_traits::dtype>); using dst_seq = 
make_index_seq<(S::value) ...>; return impl::set_slice_impl(std::forward(dst_c), std::forward(src_c), dst_seq{}, make_index_seq()>{}); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // BELOW IS AMDGPU SPECIFIC TYPES/ARCH/INTRINSICS ///////////////////////////////////////////////////////////////////////////////////////////////////////// // address space attribute #if defined(__HIP_DEVICE_COMPILE__) #define OPUS_LDS_ADDR __attribute__((address_space(3))) #else #define OPUS_LDS_ADDR #endif // dtype, suffix is "_t", and register corresponding ext_vector_type, and a specialization of is_dtype #define REGISTER_DTYPE(dtype_base_, dtype_impl_) \ using dtype_base_ ## _t = dtype_impl_; \ using dtype_base_ ## x1_t = dtype_base_ ## _t __attribute__((ext_vector_type(1 ))); \ using dtype_base_ ## x2_t = dtype_base_ ## _t __attribute__((ext_vector_type(2 ))); \ using dtype_base_ ## x4_t = dtype_base_ ## _t __attribute__((ext_vector_type(4 ))); \ using dtype_base_ ## x8_t = dtype_base_ ## _t __attribute__((ext_vector_type(8 ))); \ using dtype_base_ ## x16_t = dtype_base_ ## _t __attribute__((ext_vector_type(16))); \ using dtype_base_ ## x32_t = dtype_base_ ## _t __attribute__((ext_vector_type(32))); \ using dtype_base_ ## x64_t = dtype_base_ ## _t __attribute__((ext_vector_type(64))); \ template<> struct is_dtype : true_type {}; template struct is_dtype : false_type {}; template constexpr bool is_dtype_v = is_dtype>::value; // use this! 
REGISTER_DTYPE(fp32, float) #if __clang_major__ >= 20 // enable for rocm 7.0+ REGISTER_DTYPE(bf16, __bf16) REGISTER_DTYPE(fp16, __fp16) #else REGISTER_DTYPE(bf16, unsigned short) REGISTER_DTYPE(fp16, _Float16) #endif REGISTER_DTYPE(fp8 , _BitInt(8)) REGISTER_DTYPE(bf8 , unsigned _BitInt(8)) REGISTER_DTYPE(i32 , int) REGISTER_DTYPE(u32 , unsigned int) REGISTER_DTYPE(i16 , short) #if __clang_major__ >= 20 REGISTER_DTYPE(u16 , unsigned short) #endif REGISTER_DTYPE(i8 , signed char) REGISTER_DTYPE(u8 , unsigned char) /////////////////////////////////////////////////////////////////////////////////////////////////////////// // numeric_limits -- returns min/max/lowest/quiet_nan/infinity in the *original* dtype // (see finfo below for float-valued properties like eps/max/min/tiny) template struct numeric_limits; template<> struct numeric_limits { static constexpr unsigned int bin_min = 0x00800000, bin_max = 0x7F7FFFFF, bin_lowest = 0xFF7FFFFF, bin_qnan = 0x7FC00000, bin_inf = 0x7F800000; OPUS_H_D static constexpr fp32_t min() { return __builtin_bit_cast(fp32_t, bin_min); } OPUS_H_D static constexpr fp32_t max() { return __builtin_bit_cast(fp32_t, bin_max); } OPUS_H_D static constexpr fp32_t lowest() { return __builtin_bit_cast(fp32_t, bin_lowest); } OPUS_H_D static constexpr fp32_t quiet_nan() { return __builtin_bit_cast(fp32_t, bin_qnan); } OPUS_H_D static constexpr fp32_t infinity() { return __builtin_bit_cast(fp32_t, bin_inf); } }; template<> struct numeric_limits { static constexpr unsigned short bin_min = 0x0400, bin_max = 0x7BFF, bin_lowest = 0xFBFF, bin_qnan = 0x7E00, bin_inf = 0x7C00; OPUS_H_D static constexpr fp16_t min() { return __builtin_bit_cast(fp16_t, bin_min); } OPUS_H_D static constexpr fp16_t max() { return __builtin_bit_cast(fp16_t, bin_max); } OPUS_H_D static constexpr fp16_t lowest() { return __builtin_bit_cast(fp16_t, bin_lowest); } OPUS_H_D static constexpr fp16_t quiet_nan() { return __builtin_bit_cast(fp16_t, bin_qnan); } OPUS_H_D static constexpr 
fp16_t infinity() { return __builtin_bit_cast(fp16_t, bin_inf); } }; template<> struct numeric_limits { static constexpr unsigned short bin_min = 0x0080, bin_max = 0x7F7F, bin_lowest = 0xFF7F, bin_qnan = 0x7FC0, bin_inf = 0x7F80; OPUS_H_D static constexpr bf16_t min() { return __builtin_bit_cast(bf16_t, bin_min); } OPUS_H_D static constexpr bf16_t max() { return __builtin_bit_cast(bf16_t, bin_max); } OPUS_H_D static constexpr bf16_t lowest() { return __builtin_bit_cast(bf16_t, bin_lowest); } OPUS_H_D static constexpr bf16_t quiet_nan() { return __builtin_bit_cast(bf16_t, bin_qnan); } OPUS_H_D static constexpr bf16_t infinity() { return __builtin_bit_cast(bf16_t, bin_inf); } }; // fp8 E4M3: gfx950=OCP(ieee-like, NaN=0x7F), gfx942=fnuz(NaN=0x80). No infinity in either format, so bin_inf is 0x00 and infinity() returns the all-zero bit pattern. // NOTE: __builtin_bit_cast with _BitInt(8) is not yet constexpr in clang, so use static_cast via signed char. template<> struct numeric_limits { #if defined(__gfx942__) static constexpr unsigned char bin_min = 0x08, bin_max = 0x7F, bin_lowest = 0xFF, bin_qnan = 0x80, bin_inf = 0x00; #else static constexpr unsigned char bin_min = 0x08, bin_max = 0x7E, bin_lowest = 0xFE, bin_qnan = 0x7F, bin_inf = 0x00; #endif OPUS_H_D static constexpr fp8_t min() { return static_cast(static_cast(bin_min)); } OPUS_H_D static constexpr fp8_t max() { return static_cast(static_cast(bin_max)); } OPUS_H_D static constexpr fp8_t lowest() { return static_cast(static_cast(bin_lowest)); } OPUS_H_D static constexpr fp8_t quiet_nan() { return static_cast(static_cast(bin_qnan)); } OPUS_H_D static constexpr fp8_t infinity() { return static_cast(static_cast(bin_inf)); } }; // bf8 E5M2: gfx950=OCP(ieee, has inf=0x7C, NaN=0x7E), gfx942=fnuz(no inf, NaN=0x80) // NOTE(review): the non-gfx942 bin_qnan initializer below is 0x7F, not the 0x7E quoted above -- confirm which encoding is intended. On gfx942 bin_inf is 0x00, so infinity() returns the all-zero bit pattern there. template<> struct numeric_limits { #if defined(__gfx942__) static constexpr unsigned char bin_min = 0x04, bin_max = 0x7F, bin_lowest = 0xFF, bin_qnan = 0x80, bin_inf = 0x00; #else static constexpr unsigned char bin_min = 0x04, bin_max = 0x7B, bin_lowest = 0xFB, bin_qnan = 
0x7F, bin_inf = 0x7C; #endif OPUS_H_D static constexpr bf8_t min() { return static_cast(bin_min); } OPUS_H_D static constexpr bf8_t max() { return static_cast(bin_max); } OPUS_H_D static constexpr bf8_t lowest() { return static_cast(bin_lowest); } OPUS_H_D static constexpr bf8_t quiet_nan() { return static_cast(bin_qnan); } OPUS_H_D static constexpr bf8_t infinity() { return static_cast(bin_inf); } }; template<> struct numeric_limits { OPUS_H_D static constexpr i32_t min() { return -2147483647 - 1; } OPUS_H_D static constexpr i32_t max() { return 2147483647; } OPUS_H_D static constexpr i32_t lowest() { return -2147483647 - 1; } OPUS_H_D static constexpr i32_t quiet_nan() { return 0; } OPUS_H_D static constexpr i32_t infinity() { return 0; } }; template<> struct numeric_limits { OPUS_H_D static constexpr u32_t min() { return 0; } OPUS_H_D static constexpr u32_t max() { return 4294967295U; } OPUS_H_D static constexpr u32_t lowest() { return 0; } OPUS_H_D static constexpr u32_t quiet_nan() { return 0; } OPUS_H_D static constexpr u32_t infinity() { return 0; } }; template<> struct numeric_limits { OPUS_H_D static constexpr i16_t min() { return -32768; } OPUS_H_D static constexpr i16_t max() { return 32767; } OPUS_H_D static constexpr i16_t lowest() { return -32768; } OPUS_H_D static constexpr i16_t quiet_nan() { return 0; } OPUS_H_D static constexpr i16_t infinity() { return 0; } }; #if __clang_major__ >= 20 template<> struct numeric_limits { OPUS_H_D static constexpr u16_t min() { return 0; } OPUS_H_D static constexpr u16_t max() { return 65535; } OPUS_H_D static constexpr u16_t lowest() { return 0; } OPUS_H_D static constexpr u16_t quiet_nan() { return 0; } OPUS_H_D static constexpr u16_t infinity() { return 0; } }; #endif template<> struct numeric_limits { OPUS_H_D static constexpr i8_t min() { return -128; } OPUS_H_D static constexpr i8_t max() { return 127; } OPUS_H_D static constexpr i8_t lowest() { return -128; } OPUS_H_D static constexpr i8_t quiet_nan() { 
return 0; } OPUS_H_D static constexpr i8_t infinity() { return 0; } }; template<> struct numeric_limits { OPUS_H_D static constexpr u8_t min() { return 0; } OPUS_H_D static constexpr u8_t max() { return 255; } OPUS_H_D static constexpr u8_t lowest() { return 0; } OPUS_H_D static constexpr u8_t quiet_nan() { return 0; } OPUS_H_D static constexpr u8_t infinity() { return 0; } }; /////////////////////////////////////////////////////////////////////////////////////////////////////////// // finfo -- like torch.finfo: eps/max/min/tiny as float, bits as int template struct finfo; template<> struct finfo { static constexpr int bits = 32; OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x34000000u); } // 2^-23 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F7FFFFFu); } // 3.4028235e+38 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xFF7FFFFFu); } // -3.4028235e+38 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00800000u); } // 2^-126 }; template<> struct finfo { static constexpr int bits = 16; OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3A800000u); } // 2^-10 = 9.765625e-4 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x477FE000u); } // 65504.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC77FE000u); } // -65504.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38800000u); } // 2^-14 }; template<> struct finfo { static constexpr int bits = 16; OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3C000000u); } // 2^-7 = 0.0078125 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F7F0000u); } // 3.389531e+38 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xFF7F0000u); } // -3.389531e+38 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00800000u); } // 
2^-126 }; // fp8 E4M3: gfx950=OCP(float8_e4m3fn, bias=7), gfx942=fnuz(float8_e4m3fnuz, bias=8) template<> struct finfo { static constexpr int bits = 8; OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E000000u); } // 2^-3 = 0.125 #if defined(__gfx942__) OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x43700000u); } // 240.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC3700000u); } // -240.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3C000000u); } // 2^-7 = 0.0078125 #else OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x43E00000u); } // 448.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC3E00000u); } // -448.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3C800000u); } // 2^-6 = 0.015625 #endif }; // bf8 E5M2: gfx950=OCP(float8_e5m2, bias=15), gfx942=fnuz(float8_e5m2fnuz, bias=16) template<> struct finfo { static constexpr int bits = 8; #if defined(__gfx942__) OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E000000u); } // 2^-3 = 0.125 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x47600000u); } // 57344.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC7600000u); } // -57344.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38000000u); } // 2^-15 #else OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E800000u); } // 2^-2 = 0.25 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x47600000u); } // 57344.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC7600000u); } // -57344.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38800000u); } // 2^-14 #endif }; template<> struct finfo { static constexpr int bits = 8; OPUS_H_D static constexpr float max() { 
return 127.0f; } OPUS_H_D static constexpr float min() { return -128.0f; } }; template && (is_constant_v && ...), bool> = true> OPUS_H_D constexpr auto slice(C&& container, S&&.../*ss*/) { return container; } // TODO: fallback slice a normal value does nothing ///////////////////////////////////////////////////////////////////////////////////////////////////////// // type cast OPUS_D bf16_t fp32_to_bf16_rtn_asm(const float& x) { union { float f; u32_t i; } u = {x}; constexpr u32_t f32_nan = 0x7fff0000; constexpr u32_t round_bias = 0x7fff; u32x2_t check_nan; u32_t tmp; asm volatile("\nv_cmp_u_f32 %0, %2, %2 \nv_bfe_u32 %1, %2, 16, 1 \nv_add3_u32 %1, %2, %1, %3 \nv_cndmask_b32 %2, %1, %4, %0 \nv_lshrrev_b32 %2, 16, %2 \n" : "=s"(check_nan), "+v"(tmp), "+v"(u.f) : "v"(round_bias), "v"(f32_nan)); /* after the final v_lshrrev_b32 the low 16 bits of u.i are the bf16 bit pattern; reinterpret them bit-for-bit so the result is also correct when bf16_t is __bf16 (a plain bf16_t(u.i) would then be an integer-to-float value conversion); identical to the old truncation when bf16_t is unsigned short */ return __builtin_bit_cast(bf16_t, static_cast<unsigned short>(u.i)); } OPUS_D constexpr auto fp16_to_fp32(const fp16_t& x) { return static_cast(x); } OPUS_D constexpr auto fp32_to_fp16(const fp32_t& x) { return static_cast(x); } OPUS_D constexpr auto bf16_to_fp32(const bf16_t& x) { union { u32_t i; float f; } u = {static_cast(__builtin_bit_cast(unsigned short, x)) << 16}; return u.f;} OPUS_D constexpr unsigned short fp32_to_bf16_rtn_raw(float f) { unsigned int bits = __builtin_bit_cast(unsigned int, f); if(~bits & 0x7f800000) { bits += 0x7fff + ((bits >> 16) & 1); /* Round to nearest even */ } else if(bits & 0xffff) { bits |= 0x10000; /* Preserve signaling NaN */ } return static_cast(bits >> 16); } #if (defined(__gfx950__) || defined(__gfx1250__)) && __clang_major__ >= 20 template // gfx950/gfx1250 has instruction conversion, leave 'rm' here for compatibility OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number = {}) { return static_cast(x); } #else template // 0:standard, 1:truncate_with_nan, 2:truncate, 3:standard asm 4:rta_asm(round to nearest away) -- NOTE: only rm 0..3 have a branch below; any other value (including 4) matches no branch and the function then returns void OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number = {}) { if constexpr (rm == 0) {return __builtin_bit_cast(bf16_t, fp32_to_bf16_rtn_raw(x)); } else if 
constexpr (rm == 1) {u32_t z = __builtin_bit_cast(u32_t, x); /* truncate, but for a NaN whose payload lives only in the low 16 bits set bit 16 first so the result stays NaN; the OR must be parenthesized because >> binds tighter than |, otherwise the cast keeps the low half of z */ return __builtin_bit_cast(bf16_t, static_cast((z | (!(~z & 0x7f800000) && (z & 0xffff) ? 0x10000 : 0)) >> 16)); } else if constexpr (rm == 2) {u32_t z = __builtin_bit_cast(u32_t, x); return __builtin_bit_cast(bf16_t, static_cast(z >> 16)); } else if constexpr (rm == 3) { return fp32_to_bf16_rtn_asm(x); } } #endif #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" #pragma clang diagnostic ignored "-Wc++20-extensions" // scalar fp8 <-> fp32 via packed intrinsics (lo slot only). NOT constexpr: clang eagerly rejects non-template // constexpr functions containing GPU builtins (__builtin_amdgcn_cvt_*) that can never be compile-time evaluated. // Template constexpr (packed variants, OPUS_CAST_DEFINE) survives because the check is deferred to instantiation. // TODO: we may remove constexpr from cast in the future OPUS_D auto fp32_to_fp8(const fp32_t& x) { // int w; w = __builtin_amdgcn_cvt_pk_fp8_f32(x, 0.0f, w, /*sel=lo*/0); // return __builtin_bit_cast(fp8_t, static_cast(w)); #if defined(__gfx938__) || defined(__gfx946__) int w; w = __builtin_hcu_cvt_pk_fp8_f32(x, 0.0f, w, 0); return __builtin_bit_cast(fp8_t, static_cast(w)); #else return 0.0f; #endif } OPUS_D auto fp8_to_fp32(const fp8_t& x) { // int w = static_cast(__builtin_bit_cast(unsigned char, x)); // return __builtin_amdgcn_cvt_f32_fp8(w, /*byte=*/0); #if defined(__gfx938__) || defined(__gfx946__) int w = static_cast(__builtin_bit_cast(unsigned char, x)); return __builtin_hcu_cvt_f32_fp8(w, 0, 0, 0); #else return 0.0f; #endif } OPUS_D constexpr auto fp32_to_fp32(const fp32_t& x) { return x; } OPUS_D constexpr auto fp32_to_i8(const fp32_t& x) { return static_cast(x); } OPUS_D constexpr auto i8_to_fp32(const i8_t& x) { return static_cast(x); } #pragma clang diagnostic pop #define OPUS_CAST_DEFINE(d_, s_) template && std::is_same_v, bool> = true> \ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... 
aux) { return s_ ## _to_ ## d_(s, std::forward(aux)...); } OPUS_CAST_DEFINE(fp16, fp32) OPUS_CAST_DEFINE(fp32, fp16) OPUS_CAST_DEFINE(bf16, fp32) OPUS_CAST_DEFINE(fp32, bf16) OPUS_CAST_DEFINE(fp8, fp32) OPUS_CAST_DEFINE(fp32, fp8) OPUS_CAST_DEFINE(fp32, fp32) OPUS_CAST_DEFINE(i8, fp32) OPUS_CAST_DEFINE(fp32, i8) namespace impl { // implement a "pack" of data, storage should pad to multiple of byte(8bit) template struct dpacks { using storage = remove_cvref_t; static constexpr unsigned int bits = bits_; static constexpr unsigned int mask = (1 << bits) - 1; static constexpr bool is_signed = is_signed_; static constexpr unsigned int num_packs = sizeof(storage) * 8 / bits; // we will not check if evenly divided or not here OPUS_H_D constexpr storage operator[](index_t i) const { return (value >> (i * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device template OPUS_H_D constexpr storage operator[](number) const { return (value >> (I * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device storage value; }; template struct fpacks : dpacks { static constexpr unsigned int exp_bits = exp_bits_; static constexpr unsigned int mantissa_bits = mantissa_bits_; }; } // namespace impl template struct is_packs : false_type {}; template struct is_packs> : true_type {}; template struct is_packs> : true_type {}; template static constexpr bool is_packs_v = is_packs>::value; // how many real data within one byte template struct num_packs { static constexpr int value = 1; }; template struct num_packs>> { static constexpr int value = T::num_packs; }; template static constexpr int num_packs_v = num_packs::value; template struct sizeof_bits { static constexpr int value = int(sizeof(T) * 8); }; template <> struct sizeof_bits { static constexpr int value = 0; }; template struct sizeof_bits> { static constexpr int value = impl::dpacks::bits; }; template struct sizeof_bits> { static constexpr int value = impl::fpacks::bits; }; template 
static constexpr auto sizeof_bits_v = sizeof_bits::value; #define OPUS_DEFINE_DPACKS(name_, storage_, bits_, is_signed_) \ struct name_ : opus::impl::dpacks { using base = opus::impl::dpacks; }; \ template<> struct sizeof_bits { static constexpr int value = name_::bits; }; template<> struct is_packs : true_type {}; template<> struct is_dtype : true_type {}; #define OPUS_DEFINE_FPACKS(name_, storage_, bits_, exp_bits_, mantissa_bits_, is_signed_) \ struct name_ : opus::impl::fpacks {using base = opus::impl::fpacks; }; \ template<> struct sizeof_bits { static constexpr int value = name_::bits; }; template<> struct is_packs : true_type {}; template<> struct is_dtype : true_type {}; // NOTE: convention here. The subbyte type below is indeed "packed" data. e.g. fp4_t, underneath it is fp4x2 in one byte, but we don't name it this way // This is different from cutlass convention (e.g float4_e2m1_t, but storage is unsigned char, hence an array of float4_e2m1_t will be expanded), and different from ck convention(explicitly name it fp4x2_t) OPUS_DEFINE_DPACKS(int4_t , unsigned char, 4, true) // int4x2 OPUS_DEFINE_DPACKS(uint4_t, unsigned char, 4, false) // uint4x2 OPUS_DEFINE_FPACKS(fp4_t, unsigned char, 4, 2, 1, true) // fp4x2 OPUS_DEFINE_FPACKS(e8m0_t, unsigned char, 8, 8, 0, false) // fp4x2 // finfo specializations for subbyte/packed types (defined after OPUS_DEFINE_FPACKS) // fp4 E2M1: 1 sign, 2 exp, 1 mantissa, bias=1 template<> struct finfo { static constexpr int bits = 4; OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3F000000u); } // 2^-1 = 0.5 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x40C00000u); } // 6.0 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC0C00000u); } // -6.0 OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3F800000u); } // 1.0 }; // e8m0: 8-bit exponent only, unsigned, bias=127 template<> struct finfo { static constexpr int bits = 8; 
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3F800000u); } // 1.0 OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F000000u); } // 2^127 OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0x00400000u); } // 2^-127 (unsigned, no negative) OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00400000u); } // 2^-127 }; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wuninitialized" #pragma clang diagnostic ignored "-Wc++20-extensions" template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp8_packed_x2(const S& s, number = {}) { int w ; #if defined(__gfx938__) || defined(__gfx946__) w = __builtin_hcu_cvt_pk_fp8_f32(s[0], s[1], w, (bool)sel); #else // w = __builtin_amdgcn_cvt_pk_fp8_f32(s[0], s[1], w, sel); w = 0; #endif return __builtin_bit_cast(fp8x2_t, static_cast(w)); } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp8_packed_x4(const S& s) { int w ; #if defined(__gfx938__) || defined(__gfx946__) w = __builtin_hcu_cvt_pk_fp8_f32(s[0], s[1], w, false); w = __builtin_hcu_cvt_pk_fp8_f32(s[2], s[3], w, true); #else // w = __builtin_amdgcn_cvt_pk_fp8_f32(s[0], s[1], w, 0); w = __builtin_amdgcn_cvt_pk_fp8_f32(s[2], s[3], w, 1); w = 0; #endif return __builtin_bit_cast(fp8x4_t, w); } template, bool> = true> OPUS_D constexpr decltype(auto) fp8_to_fp32_packed_x2(const S& s, number = {}) { union { int bitwise; S f8_packs[2]; } value; value.f8_packs[0] = s; #if defined(__gfx938__) || defined(__gfx946__) return __builtin_hcu_cvt_pk_f32_fp8(value.bitwise, sel); #else // return __builtin_amdgcn_cvt_pk_f32_fp8(value.bitwise, sel); return 0; #endif } template, bool> = true> OPUS_D constexpr decltype(auto) fp8_to_fp32_packed_x4(const S& s) { int bitwise = __builtin_bit_cast(int, s); #if defined(__gfx938__) || defined(__gfx946__) auto x = __builtin_hcu_cvt_pk_f32_fp8(bitwise, 0); auto y = __builtin_hcu_cvt_pk_f32_fp8(bitwise, 1); return 
fp32x4_t{x[0], x[1], y[0], y[1]}; #else // auto x = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 0); auto y = __builtin_amdgcn_cvt_pk_f32_fp8(bitwise, 1); return fp32x4_t{0, 0, 0, 0}; #endif } namespace impl { template OPUS_D constexpr decltype(auto) fold_as_tuple_of_vec(const S& s, seq) { static_assert(size() % sizeof...(Xs) == 0); constexpr index_t Y_len = size() / sizeof...(Xs); auto gen_ = [&](number, seq){ return vector_t, Y_len>{get(s)...}; }; return make_tuple(gen_(number{}, make_index_seq{})...); } template OPUS_D constexpr decltype(auto) fold_as_tuple_of_arr(const S& s, seq) { static_assert(size() % sizeof...(Xs) == 0); constexpr index_t Y_len = size() / sizeof...(Xs); auto gen_ = [&](number, seq){ return array, Y_len>{get(s)...}; }; return make_tuple(gen_(number{}, make_index_seq{})...); } template || is_vector_v || is_array_v, bool> = true> OPUS_D constexpr decltype(auto) fold_as_container_of_vec(const S& s, number) { static_assert(size() % fold_size == 0); return fold_as_tuple_of_vec(s, make_index_seq() / fold_size>{}); } template || is_vector_v || is_array_v, bool> = true> OPUS_D constexpr decltype(auto) fold_as_container_of_arr(const S& s, number) { static_assert(size() % fold_size == 0); return fold_as_tuple_of_arr(s, make_index_seq() / fold_size>{}); } // Unfold a tuple-of-sub-results (produced by auto-fold cast) back into a flat container. Used in pair with above // matching the original input container type OrigS. 
// OrigS is vector -> flat vector_t // OrigS is array -> flat array // OrigS is tuple -> flat tuple template OPUS_D constexpr auto unfold_as_tuple(const Tup& tup, number, seq) { return make_tuple(get(getv(tup))...); } template, bool> = true> OPUS_D constexpr auto unfold_from_container(const Tup& tup) { using inner_t = remove_cvref_t(tup))>; using elem_t = get_value_t; constexpr index_t outer_n = opus::size(); constexpr index_t inner_n = opus::size(); constexpr index_t total_n = outer_n * inner_n; if constexpr (is_vector_v) { vector_t r; static_for([&](auto f) { r[f.value] = get(getv(tup)); }); return r; } else if constexpr (is_array_v) { array r; static_for([&](auto f) { r[f.value] = get(getv(tup)); }); return r; } else { /* tuple */ return unfold_as_tuple(tup, number{}, make_index_seq{}); } } } // namespace impl #if defined(__gfx950__) template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x2(const S& s, float scale = 1.0f, number = {}) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[0], s[1], scale, sel); return __builtin_bit_cast(array, static_cast(w)); } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x4(const S& s, float scale = 1.0f) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[0], s[1], scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[2], s[3], scale, 1); return __builtin_bit_cast(array, static_cast(w)); } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x8(const S& s, float scale = 1.0f) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[0], s[1], scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[2], s[3], scale, 1); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[4], s[5], scale, 2); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(w, s[6], s[7], scale, 3); return __builtin_bit_cast(array, w); } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x2(const S& s, float scale = 1.0f, number = {}) { return 
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(static_cast(__builtin_bit_cast(u8_t, s)), scale, sel); } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x4(const S& s, float scale = 1.0f) { auto ss = static_cast(__builtin_bit_cast(u16_t, s)); auto x = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 0); auto y = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 1); return fp32x4_t{x[0], x[1], y[0], y[1]}; } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& s, float scale = 1.0f) { auto ss = static_cast(__builtin_bit_cast(u32_t, s)); auto x = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 0); auto y = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 1); auto z = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 2); auto w = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(ss, scale, 3); return fp32x8_t{x[0], x[1], y[0], y[1], z[0], z[1], w[0], w[1]}; } template, bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& s, float scale = 1.0f, number = {}) { union { unsigned int bitwise; fp4_t fp4_pack[4]; } value; value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(value.bitwise, s, scale, sel); return value.fp4_pack[0]; } template, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& s, float scale = 1.0f, number = {}) { return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(s, scale, sel); } #elif defined(__gfx1250__) // gfx1250: pk8 builtins convert 8 fp4 <-> 8 f32 at once // f32->fp4: __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(v8f32 src, float scale) -> i32 // fp4->f32: __builtin_amdgcn_cvt_scale_pk8_f32_fp4(i32 src, i32 scale_sel, i32 imm) -> v8f32 // scale_sel = e8m0 scale byte (imm selects which byte), e8m0: val = 2^(byte-127), so 1.0 = 0x7F // extract e8m0 from float: biased exponent = (float_bits >> 23) & 0xFF template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x2(const S& s, float scale = 1.0f, number = {}) { fp32x8_t v{s[0], 
s[1], 0, 0, 0, 0, 0, 0}; u32_t w = __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(v, scale); return __builtin_bit_cast(array, static_cast(w)); } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x4(const S& s, float scale = 1.0f) { fp32x8_t v{s[0], s[1], s[2], s[3], 0, 0, 0, 0}; u32_t w = __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(v, scale); return __builtin_bit_cast(array, static_cast(w)); } template, bool> = true> OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x8(const S& s, float scale = 1.0f) { u32_t w = __builtin_amdgcn_cvt_scalef32_pk8_fp4_f32(s, scale); return __builtin_bit_cast(array, w); } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x2(const S& s, float scale = 1.0f, number = {}) { i32_t e = (__builtin_bit_cast(i32_t, scale) >> 23) & 0xFF; i32_t scale_e8m0 = e * static_cast(0x01010101); fp32x8_t r = __builtin_amdgcn_cvt_scale_pk8_f32_fp4(static_cast(__builtin_bit_cast(u8_t, s)), scale_e8m0, 0); return fp32x2_t{r[0], r[1]}; } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x4(const S& s, float scale = 1.0f) { i32_t e = (__builtin_bit_cast(i32_t, scale) >> 23) & 0xFF; i32_t scale_e8m0 = e * static_cast(0x01010101); fp32x8_t r = __builtin_amdgcn_cvt_scale_pk8_f32_fp4(static_cast(__builtin_bit_cast(u16_t, s)), scale_e8m0, 0); return fp32x4_t{r[0], r[1], r[2], r[3]}; } template>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& s, float scale = 1.0f) { i32_t e = (__builtin_bit_cast(i32_t, scale) >> 23) & 0xFF; i32_t scale_e8m0 = e * static_cast(0x01010101); fp32x8_t r = __builtin_amdgcn_cvt_scale_pk8_f32_fp4(static_cast(__builtin_bit_cast(u32_t, s)), scale_e8m0, 0); return fp32x8_t{r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7]}; } // bf16<->fp4 stubs for gfx1250 (no pk bf16<->fp4 builtins available) template, bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return fp4_t{}; } 
template, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f, number = {}) { return bf16x2_t{}; }
#else
// Fallback for every target without fp4 conversion builtins: each function returns a value-initialized
// result of the expected type and ignores its inputs.
// NOTE(review): template argument lists (the text between '<' and '>') are missing throughout this section
// and are reproduced as found; restore them from the upstream header.
// NOTE(review): unlike the arch-specific x2 variants, the x2 stubs below take no trailing `number = {}`
// selector parameter, so a call passing a selector would not match them -- confirm whether that is intended.
template, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; }
template, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; }
template, bool> = true>
OPUS_D constexpr decltype(auto) fp32_to_fp4_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f) { return array{}; }
template>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return fp32x2_t{}; }
template>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x4(const S& /*s*/, float /*scale*/ = 1.0f) { return fp32x4_t{}; }
template>, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_fp32_packed_x8(const S& /*s*/, float /*scale*/ = 1.0f) { return fp32x8_t{}; }
template, bool> = true>
OPUS_D constexpr decltype(auto) bf16_to_fp4_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return fp4_t{}; }
template, bool> = true>
OPUS_D constexpr decltype(auto) fp4_to_bf16_packed_x2(const S& /*s*/, float /*scale*/ = 1.0f) { return bf16x2_t{}; }
#endif
#pragma clang diagnostic pop

// cast() overloads for the fixed-width packed conversions: each forwards `s` and any extra arguments
// to the matching *_packed_xN helper.
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp32_to_fp8_packed_x2(s, std::forward(aux)...); }
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp32_to_fp8_packed_x4(s, std::forward(aux)...); }
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp8_to_fp32_packed_x2(s, std::forward(aux)...); }
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp8_to_fp32_packed_x4(s, std::forward(aux)...); }

template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp32_to_fp4_packed_x2(s, std::forward(aux)...); }
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp32_to_fp4_packed_x4(s, std::forward(aux)...); }
template && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp32_to_fp4_packed_x8(s, std::forward(aux)...); }
template> && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp4_to_fp32_packed_x2(s, std::forward(aux)...); }
template> && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp4_to_fp32_packed_x4(s, std::forward(aux)...); }
template> && std::is_same_v, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return fp4_to_fp32_packed_x8(s, std::forward(aux)...); }

namespace impl {
// Element-wise cast over a container: one overload each for vector, tuple and array inputs. Each applies
// cast() to every element selected by the index sequence and builds the same kind of container.
// rocm-7.1.1, when there are multiple invokes of this kernel (across different __global__ in same compile target ?) will fail to inline below function
template, bool> = true>
OPUS_D constexpr decltype(auto) cast_impl(const S& s, seq, Aux&&... aux) { return impl::vector_return_type(get(s), std::forward(aux)...))...>{cast(get(s), std::forward(aux)...)...}; } //return opus::make_vector(cast(get(s), std::forward(aux)...)...); }

template, bool> = true>
OPUS_D constexpr decltype(auto) cast_impl(const S& s, seq, Aux&&... aux) { return tuple(get(s), std::forward(aux)...))>...>(cast(get(s), std::forward(aux)...)...); } // return opus::make_tuple(cast(get(s), std::forward(aux)...) ... ); }

template, bool> = true>
OPUS_D constexpr decltype(auto) cast_impl(const S& s, seq, Aux&&... aux) { return impl::array_return_type(get(s), std::forward(aux)...))...>{cast(get(s), std::forward(aux)...)...}; } // return opus::make_array(cast(get(s), std::forward(aux)...)...); }
}

// entry point for vectorized cast(), non-dpacks
// fp32<->fp8 inputs whose size is a multiple of 4 (or else 2) are folded into groups, cast group by group
// through the packed overloads above, then unfolded back to a flat container. Vectors longer than 16 with
// no extra arguments use __builtin_convertvector; everything else is cast element by element.
template || is_tuple_v || is_array_v) && !is_packs_v && !is_packs_v>) && !(is_any_of_v&& std::is_same_v) && !(is_any_of_v&& std::is_same_v) , bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
    if constexpr (std::is_same_v, fp32_t> && size() % 4 == 0 && std::is_same_v) { // fp32 -> fp8 , x4N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<4>{}), make_index_seq() / 4>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp32_t> && size() % 2 == 0 && std::is_same_v) { // fp32 -> fp8 , x2N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<2>{}), make_index_seq() / 2>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp8_t> && size() % 4 == 0 && std::is_same_v) { // fp8 -> fp32, x4N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<4>{}), make_index_seq() / 4>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp8_t> && size() % 2 == 0 && std::is_same_v) { // fp8 -> fp32, x2N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<2>{}), make_index_seq() / 2>{}, std::forward(aux)...));
    } else if constexpr (is_vector_v && size() > 16 && sizeof...(Aux) == 0) {
        return __builtin_convertvector(s, vector_t()>);
    } else
        return impl::cast_impl(s, make_index_seq()>{}, std::forward(aux)...);
}

// entry point for vectorized cast(), for dpacks
// Same fold / cast / unfold scheme for packed sub-byte types. num_packs_ is the number of logical values
// per storage element, taken from D when D is a packed type and from the source element type otherwise.
template || is_tuple_v || is_array_v) && (is_packs_v || is_packs_v>)) && !(is_any_of_v && std::is_same_v) // fp32
                                                                                                                                            && !(is_any_of_v, array, array, tuple_array, tuple_array, tuple_array> && std::is_same_v) , bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
    constexpr index_t num_packs_ = [&](){
        // TODO: how to consider both D and S are packs?
        if constexpr (is_packs_v) { static_assert(size() % D::num_packs == 0); return D::num_packs; } // TODO: do not support cast pack data one by one
        else { return get_value_t::num_packs; }
    }();
    if constexpr (std::is_same_v, fp32_t> && size() % 8 == 0 && std::is_same_v) { // fp32 -> fp4 , x8N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<8>{}), make_index_seq() / 8>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp32_t> && size() % 4 == 0 && std::is_same_v) { // fp32 -> fp4 , x4N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<4>{}), make_index_seq() / 4>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp32_t> && size() % 2 == 0 && std::is_same_v) { // fp32 -> fp4 , x2N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number<2>{}), make_index_seq() / 2>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp4_t> && size() % 4 == 0) { // fp4 -> fp32 , x8N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_arr(s, number<4>{}), make_index_seq() / 4>{}, std::forward(aux)...));
    } else if constexpr (std::is_same_v, fp4_t> && size() % 2 == 0) { // fp4 -> fp32 , x4N
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_arr(s, number<2>{}), make_index_seq() / 2>{}, std::forward(aux)...));
    } else
        return impl::unfold_from_container(impl::cast_impl(impl::fold_as_container_of_vec(s, number{}), make_index_seq() / num_packs_>{}, std::forward(aux)...));
}
#undef OPUS_DEFINE_DPACKS
#undef OPUS_DEFINE_FPACKS
#undef OPUS_CAST_DEFINE

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// arch
//
// ---- HIPCC compilation model (clang-based) ----
// hipcc compiles each translation unit in TWO passes: host pass, then device pass.
// // Host pass : __device__ functions are fully parsed, name-resolved, template-instantiated, and constexpr/static_assert evaluated. Only machine code generation is skipped. // Device pass: __host__ functions are truly skipped -- not parsed, not instantiated, not checked. // // Key consequences: // 1. Architecture macros (__GFX9__, __gfx950__, etc.) are defined ONLY during the device pass. Any #if guard on them will take the #else branch during the host pass. // 2. __device__ constexpr variables and static_asserts inside __device__ templates are still evaluated during the host pass (since templates may be instantiated from __global__). // 3. If your device code relies on arch-specific preprocessor branches, consider guarding the entire implementation with #if defined(__HIP_DEVICE_COMPILE__) to skip the host pass. // // ---- get_warp_size() / get_smem_size() ---- // OPUS_H_D constexpr -- safe to use everywhere: template defaults, static_assert, constexpr variables, __shared__ array sizes, host launch-parameter calculations, etc. // During the host pass (arch macros absent), they return safe defaults: // get_warp_size() -> 64 (GFX9 default), 32 for gfx1250 (wave32) // get_smem_size() -> 65536 (64 KB, non-gfx950 default) // Note: __builtin_amdgcn_wavefrontsize() is NOT constexpr in clang, so it cannot be used in template arguments, static_assert, or if constexpr. Prefer get_warp_size() which uses // preprocessor arch detection to provide a constexpr result. // // ---- query_warp_size() / query_smem_size() ---- // OPUS_H only -- runtime HIP API queries (hipGetDeviceProperties). Use when you need the true hardware value on the host (e.g. occupancy calculations). // Guarded by OPUS_ENABLE_RUNTIME_QUERY (default 0). Define OPUS_ENABLE_RUNTIME_QUERY=1 before // including opus.hpp (or via compiler flag) to enable these functions and the hip_runtime_api.h include. 
// OPUS_H_D constexpr index_t get_warp_size() { #if defined(__gfx1250__) return 32; #elif defined(__GFX9__) || defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || !defined(__HIP_DEVICE_COMPILE__) return 64; #else return 32; #endif } OPUS_H_D constexpr index_t get_smem_size() { #if defined(__gfx950__) return 163840; // 160KB (CDNA4) #else return 65536; // 64KB #endif } // ---- Device intrinsic wrappers ---- // Replace HIP runtime macros (threadIdx.x, __syncthreads, __all, etc.) so kernels compile // with just #include — no needed. OPUS_D index_t thread_id_x() { return __builtin_amdgcn_workitem_id_x(); } OPUS_D index_t thread_id_y() { return __builtin_amdgcn_workitem_id_y(); } OPUS_D index_t thread_id_z() { return __builtin_amdgcn_workitem_id_z(); } OPUS_D index_t block_id_x() { return __builtin_amdgcn_workgroup_id_x(); } OPUS_D index_t block_id_y() { return __builtin_amdgcn_workgroup_id_y(); } OPUS_D index_t block_id_z() { return __builtin_amdgcn_workgroup_id_z(); } OPUS_D index_t block_size_x() { return __builtin_amdgcn_workgroup_size_x(); } OPUS_D index_t block_size_y() { return __builtin_amdgcn_workgroup_size_y(); } OPUS_D index_t block_size_z() { return __builtin_amdgcn_workgroup_size_z(); } OPUS_D index_t grid_size_x() { return __builtin_amdgcn_grid_size_x(); } OPUS_D index_t grid_size_y() { return __builtin_amdgcn_grid_size_y(); } OPUS_D index_t grid_size_z() { return __builtin_amdgcn_grid_size_z(); } OPUS_D void sync_threads() { __builtin_amdgcn_s_barrier(); } #if !defined(HIP_INCLUDE_HIP_AMD_DETAIL_DEVICE_LIBRARY_DECLS_H) extern "C" __device__ int __ockl_wfall_i32(int); #endif #if !defined(HIP_INCLUDE_HIP_AMD_DETAIL_WARP_FUNCTIONS_H) OPUS_D int warp_all(int predicate) { return __ockl_wfall_i32(predicate); } #endif #if OPUS_ENABLE_RUNTIME_QUERY OPUS_H index_t query_warp_size() { int d; (void)hipGetDevice(&d); hipDeviceProp_t p; (void)hipGetDeviceProperties(&p, d); return static_cast(p.warpSize); } OPUS_H index_t query_smem_size() { int d; 
(void)hipGetDevice(&d); hipDeviceProp_t p; (void)hipGetDeviceProperties(&p, d); return static_cast(p.sharedMemPerBlock); } OPUS_H index_t query_num_cu() { int d; (void)hipGetDevice(&d); hipDeviceProp_t p; (void)hipGetDeviceProperties(&p, d); return static_cast(p.multiProcessorCount); } #endif // Uses compiler builtins (__builtin_amdgcn_*) instead of HIP runtime APIs, so no dependency. #ifdef __HIPCC__ struct workgroup_barrier { OPUS_D workgroup_barrier(unsigned int* ptr) : base_ptr(ptr) {} OPUS_D unsigned int ld(unsigned int offset = 0) { return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED); } OPUS_D void wait_eq(unsigned int value, unsigned int offset = 0) { if (__builtin_amdgcn_workitem_id_x() == 0) while (ld(offset) != value) {} __builtin_amdgcn_s_barrier(); } OPUS_D void wait_lt(unsigned int value, unsigned int offset = 0) { if (__builtin_amdgcn_workitem_id_x() == 0) while (ld(offset) < value) {} __builtin_amdgcn_s_barrier(); } OPUS_D void inc(unsigned int offset = 0) { __builtin_amdgcn_s_barrier(); if (__builtin_amdgcn_workitem_id_x() == 0) __atomic_fetch_add(base_ptr + offset, 1u, __ATOMIC_RELAXED); } unsigned int* base_ptr; }; #endif // NOTE: all data in unsigned int. Prefer usage, construct a mdiv structure on host, pass the structure to kernel, and use div/divmod struct mdiv { unsigned int divisor; unsigned int multiplier; unsigned int shift; OPUS_H_D mdiv() : divisor(0), multiplier(0), shift(0) {} OPUS_H_D mdiv(unsigned int divisor_) : divisor(divisor_) { unsigned int shift_u32 = 0; while ((1U << shift_u32) < divisor_) shift_u32++; unsigned long long tmp_u64 = static_cast((1UL << shift_u32) - divisor_) << 32; multiplier = static_cast(tmp_u64 / divisor_ + 1); shift = shift_u32; } // previously we use __umulhi(), which is defined in , for __device__ compilation. 
Today compiler is smart enough to generate s_mul_hi_u32 / v_mul_hi_u32 OPUS_H_D unsigned int div(unsigned int dividend) const { unsigned int tmp = static_cast((static_cast(dividend) * multiplier) >> 32); return (tmp + dividend) >> shift; } OPUS_H_D void divmod(unsigned int dividend, unsigned int& quotient, unsigned int& remainder) const { quotient = div(dividend); remainder = dividend - (quotient * divisor); } OPUS_H_D unsigned int get() const { return divisor; } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // math template OPUS_D T mov_dpp(T x, number, number = {}, number = {}, bool_constant = {}) { static_assert(sizeof(T) == 4); return __builtin_bit_cast(T, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, x), dpp_i, row_mask, bank_mask, bound_ctrl)); } template OPUS_D T upd_dpp(const O& old, T x, number, number = {}, number = {}, bool_constant = {}) { static_assert(sizeof(T) == 4); return __builtin_bit_cast(T, __builtin_amdgcn_update_dpp(__builtin_bit_cast(int, old), __builtin_bit_cast(int, x), dpp_i, row_mask, bank_mask, bound_ctrl)); } // lane index within wavefront (threadIdx.x % warp_size, e.g. wave64: tid=3->3, tid=70->6) OPUS_D unsigned int lane_id() { if constexpr (get_warp_size() == 32) return __builtin_amdgcn_mbcnt_lo(-1, 0); else return __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0)); } // cross-lane shuffle via ds_bpermute (no hip_runtime.h dependency) template OPUS_D T shfl(T var, int src_lane, int width = get_warp_size()) { static_assert(sizeof(T) == 4); int self = lane_id(); int index = (src_lane & (width - 1)) + (self & ~(width - 1)); return __builtin_bit_cast(T, __builtin_amdgcn_ds_bpermute(index << 2, __builtin_bit_cast(int, var))); } template OPUS_D T max(const T&a, const T&b) { return a > b ? a : b; } template<> OPUS_D float max(const float&a, const float&b) { return __builtin_fmaxf(a, b); } template OPUS_D T min(const T&a, const T&b) { return a > b ? 
b : a; } template<> OPUS_D float min(const float&a, const float&b) { return __builtin_fminf(a, b); } template OPUS_D T med3(const T&a, const T&b, const T&c) { auto max_0 = max(a, b); auto min_0 = min(a, b); return min(max_0, max(min_0, c)); } template<> OPUS_D float med3(const float&a, const float&b, const float&c) { return __builtin_amdgcn_fmed3f(a, b, c); } template<> OPUS_D fp16_t med3(const fp16_t&a, const fp16_t&b, const fp16_t&c) { return __builtin_amdgcn_fmed3h(a, b, c); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // buffer load/store related OPUS_D constexpr auto buffer_default_config() { #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // [DCU] verified by 3rdparty/moe_c/csrc/intrinsic_2.h and quick_all_reduce_base.h: word3 = 0x00020000 return 0x00020000; #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) return 0x00020000; #elif defined(__gfx103__) return 0x31014000; #elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) return 0x31004000; #else return 0xffffffff; #endif } OPUS_D __amdgpu_buffer_rsrc_t make_buffer_rsrc(const void* ptr, unsigned int size = 0xffffffff, unsigned int config = buffer_default_config()) { return __builtin_amdgcn_make_buffer_rsrc(const_cast(static_cast(ptr)), 0, size, config); // void *p, short stride, int num, int flags } #if __clang_major__ < 20 #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-inline" OPUS_D void llvm_amdgcn_raw_buffer_load_lds(i32x4_t r, OPUS_LDS_ADDR unsigned int* p, index_t size, index_t vos, index_t sos, index_t ios, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); #pragma clang diagnostic pop #endif #if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) // [DCU] BF16/i16 raw buffer load/store: 
__builtin_amdgcn_raw_buffer_*_b16 produces wrong bit patterns on DCU. // CK uses LLVM intrinsic alias 'i16' form (see ck_tile/core/arch/amd_buffer_addressing.hpp + CK_TILE_BUFFER_LOAD_RAW_BF16_WA). OPUS_D short llvm_amdgcn_raw_buffer_load_i16(i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); OPUS_D void llvm_amdgcn_raw_buffer_store_i16(short vdata, i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); // [DCU] Same workaround extended to b32/b64/b128 stores; observed huge garbage values (30K+) suggest // the same DCU compiler bug affects every __builtin_amdgcn_raw_buffer_store_b{32,64,128} path. OPUS_D void llvm_amdgcn_raw_buffer_store_i32 (int vdata, i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); OPUS_D void llvm_amdgcn_raw_buffer_store_v2i32(i32x2_t vdata, i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); OPUS_D void llvm_amdgcn_raw_buffer_store_v4i32(i32x4_t vdata, i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); // [DCU] Mirror load aliases for b32/b64/b128. 
OPUS_D int llvm_amdgcn_raw_buffer_load_i32 (i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
OPUS_D i32x2_t llvm_amdgcn_raw_buffer_load_v2i32(i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
OPUS_D i32x4_t llvm_amdgcn_raw_buffer_load_v4i32(i32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
#endif

// gmem: accessor built around a buffer resource descriptor (`cached_rsrc`) that the
// constructor creates from a raw pointer via make_buffer_rsrc(ptr, size, config).
//   - `_load` / `_async_load` / `_store` take offsets in BYTES.
//   - `load` / `async_load` / `store` take offsets in units of T and scale by sizeof(T).
//   - The Layout overloads issue one access per entry of layout_to_offsets(u).
//   - The *_if overloads additionally consult a predicate per issue index.
// On gfx1250 the original pointer is also kept (`raw_ptr`) for the async-to-LDS path.
template struct gmem {
    using T = remove_cvref_t;
    using scalar_type = typename vector_traits::dtype;
    static constexpr index_t vector_size = vector_traits::size();
    template using vector_type = vector_t;

    // ptr    : base address handed to make_buffer_rsrc
    // size   : forwarded to make_buffer_rsrc (defaults to 0xffffffff)
    // config : forwarded to make_buffer_rsrc (defaults to buffer_default_config())
    OPUS_D gmem(const void* ptr, unsigned int size = 0xffffffff, unsigned int config = buffer_default_config()) : cached_rsrc(make_buffer_rsrc(ptr, size, config))
#if defined(__gfx1250__)
    , raw_ptr(static_cast(ptr))
#endif
    {}

    // Raw buffer load of one vector_type value, dispatched on sizeof(type) in {1, 2, 4, 8, 16}.
    // v_os / s_os are forwarded as the two offset operands of the intrinsic; `aux` is
    // forwarded as its last operand.
    // NOTE(review): a sizeof(type) outside {1,2,4,8,16} matches no branch, so the function
    // falls off the end without returning a value -- confirm callers never request one.
    template // os in unit of byte
    OPUS_D auto _load(int v_os, int s_os = 0, number = {}) {
        using type = vector_type;
        if constexpr (sizeof(type) == 1) { return __builtin_bit_cast(type, __builtin_amdgcn_raw_buffer_load_b8 (cached_rsrc, v_os, s_os, aux)); } // [BUILTIN-MHC] amdgcn_raw_buffer_load_b8 (DCU?)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
        // DCU targets: the descriptor is copied bytewise into an i32x4_t and the LLVM
        // intrinsic aliases declared above are called instead of the clang builtins.
        else if constexpr (sizeof(type) == 2) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); short v = llvm_amdgcn_raw_buffer_load_i16(r_, v_os, s_os, aux); return __builtin_bit_cast(type, v); } // [BUILTIN-MHC][DCU-WA] LLVM i16 alias
        else if constexpr (sizeof(type) == 4) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); int v = llvm_amdgcn_raw_buffer_load_i32 (r_, v_os, s_os, aux); return __builtin_bit_cast(type, v); } // [BUILTIN-MHC][DCU-WA] LLVM i32 alias
        else if constexpr (sizeof(type) == 8) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); i32x2_t v = llvm_amdgcn_raw_buffer_load_v2i32(r_, v_os, s_os, aux); return __builtin_bit_cast(type, v); } // [BUILTIN-MHC][DCU-WA] LLVM v2i32 alias
        else if constexpr (sizeof(type) == 16) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); i32x4_t v = llvm_amdgcn_raw_buffer_load_v4i32(r_, v_os, s_os, aux); return __builtin_bit_cast(type, v); } // [BUILTIN-MHC][DCU-WA] LLVM v4i32 alias
#else
        else if constexpr (sizeof(type) == 2) { return __builtin_bit_cast(type, __builtin_amdgcn_raw_buffer_load_b16 (cached_rsrc, v_os, s_os, aux)); } // [BUILTIN-MHC] amdgcn_raw_buffer_load_b16 (DCU?)
        else if constexpr (sizeof(type) == 4) { return __builtin_bit_cast(type, __builtin_amdgcn_raw_buffer_load_b32 (cached_rsrc, v_os, s_os, aux)); } // [BUILTIN-MHC] amdgcn_raw_buffer_load_b32 (DCU?)
        else if constexpr (sizeof(type) == 8) { return __builtin_bit_cast(type, __builtin_amdgcn_raw_buffer_load_b64 (cached_rsrc, v_os, s_os, aux)); } // [BUILTIN-MHC] amdgcn_raw_buffer_load_b64 (DCU?)
        else if constexpr (sizeof(type) == 16) { return __builtin_bit_cast(type, __builtin_amdgcn_raw_buffer_load_b128(cached_rsrc, v_os, s_os, aux)); } // [BUILTIN-MHC] amdgcn_raw_buffer_load_b128 (DCU?)
#endif
    }

    // Global -> LDS load that writes to `dst` without returning a value.
    // Four mutually exclusive code paths are selected by the preprocessor:
    //   1. gfx936/938/946 (DCU)  : LLVM raw.buffer.load.lds alias, >4B emulated with b32 pieces
    //   2. gfx1250               : global_load_async_to_lds_* on raw_ptr + v_os + s_os
    //   3. clang major >= 20     : __builtin_amdgcn_raw_ptr_buffer_load_lds
    //   4. older clang           : LLVM raw.buffer.load.lds alias on a memcpy'd descriptor
    // NOTE(review): sizes with no matching branch emit no instruction at all (for example
    // sizeof(type) == 8 on gfx950 in paths 3 and 4) -- confirm this silent no-op is intended.
    template // os in unit of byte
    OPUS_D void _async_load(OPUS_LDS_ADDR void* dst, int v_os, int s_os = 0, number = {}) {
        using type = vector_type;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
        // [DCU] hw buffer_load_lds writes all 64 lanes' results contiguously starting at m0
        // (i.e. LDS[m0 + lane*size]); it cannot deposit per-lane vectors >4B because the
        // destination is forced to stride-4 across lanes. For vec sizes 1/2/4 use the LLVM
        // intrinsic; for larger vec, emulate via multiple b32 buffer_load_lds (similar to
        // the upstream gfx9 fallback below).
        // NOTE(review): each emulated piece advances the LDS pointer by one d4 element and
        // the byte offset by 4 -- verify the resulting LDS layout against the callers.
        i32x4_t cached_rsrc_; __builtin_memcpy(&cached_rsrc_, &cached_rsrc, sizeof(i32x4_t));
        if constexpr (sizeof(type) == 1 || sizeof(type) == 2 || sizeof(type) == 4) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4, sizeof(type), v_os, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 8) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 1, 4, v_os + 4, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 16) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 1, 4, v_os + 4, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 2, 4, v_os + 8, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 3, 4, v_os + 12, s_os, 0, aux);
        }
#elif defined(__gfx1250__)
        // gfx1250: global_load_async_to_lds (global addressing, not buffer rsrc)
        // GPTR_/LPTR_ are local cast helpers (address_space(1) pointer / OPUS_LDS_ADDR pointer);
        // both are #undef'd right after this block.
        // Note: `aux` is not forwarded on this path; the two trailing operands are literal 0.
#define GPTR_(T, p) ((__attribute__((address_space(1))) T*)(p))
#define LPTR_(T, p) ((OPUS_LDS_ADDR T*)(p))
        {
            auto* src = raw_ptr + v_os + s_os;
            if constexpr (sizeof(type) == 1) {
                __builtin_amdgcn_global_load_async_to_lds_b8 (GPTR_(char, src), LPTR_(char, dst), 0, 0);
            } else if constexpr (sizeof(type) == 2) {
                // 2-byte case is issued as two consecutive b8 transfers
                __builtin_amdgcn_global_load_async_to_lds_b8 (GPTR_(char, src), LPTR_(char, dst), 0, 0);
                __builtin_amdgcn_global_load_async_to_lds_b8 (GPTR_(char, src + 1), LPTR_(char, (char*)dst + 1), 0, 0);
            } else if constexpr (sizeof(type) == 4) {
                __builtin_amdgcn_global_load_async_to_lds_b32 (GPTR_(int, src), LPTR_(int, dst), 0, 0);
            } else if constexpr (sizeof(type) == 8) {
                __builtin_amdgcn_global_load_async_to_lds_b64 (GPTR_(i32x2_t, src), LPTR_(i32x2_t, dst), 0, 0);
            } else if constexpr (sizeof(type) == 16) {
                __builtin_amdgcn_global_load_async_to_lds_b128(GPTR_(i32x4_t, src), LPTR_(i32x4_t, dst), 0, 0);
            }
        }
#undef GPTR_
#undef LPTR_
#elif __clang_major__ >= 20 // start from rocm 7.0,introduced by https://github.com/llvm/llvm-project/pull/132048, 133055, 132957
        if constexpr (sizeof(type) == 1) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 1, v_os, s_os, 0, aux); }
        else if constexpr (sizeof(type) == 2) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 2, v_os, s_os, 0, aux); }
        else if constexpr (sizeof(type) == 4) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 4, v_os, s_os, 0, aux); }
#if defined(__gfx950__)
        else if constexpr (sizeof(type) == 12) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 12, v_os, s_os, 0, aux); }
        else if constexpr (sizeof(type) == 16) { __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, dst, 16, v_os, s_os, 0, aux); }
#elif defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
        // [DCU/gfx9] hw buffer_load_lds only supports 1/2/4 bytes; emulate larger via multiple b32 loads.
        else if constexpr (sizeof(type) == 8) {
            auto* d4 = reinterpret_cast(dst);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 1, 4, v_os + 4, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 12) {
            auto* d4 = reinterpret_cast(dst);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 1, 4, v_os + 4, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 2, 4, v_os + 8, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 16) {
            auto* d4 = reinterpret_cast(dst);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 1, 4, v_os + 4, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 2, 4, v_os + 8, s_os, 0, aux);
            __builtin_amdgcn_raw_ptr_buffer_load_lds(cached_rsrc, d4 + 3, 4, v_os + 12, s_os, 0, aux);
        }
#endif
#else
        i32x4_t cached_rsrc_; __builtin_memcpy(&cached_rsrc_, &cached_rsrc, sizeof(i32x4_t)); // builtin memcpy, __builtin_bit_cast() can not use here due to __amdgpu_buffer_rsrc_t is non copyable
        if constexpr (sizeof(type) == 1) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast(dst), 1, v_os, s_os, 0, aux); } // [BUILTIN-MHC] llvm.amdgcn.raw.buffer.load.lds intrinsic (DCU? -- actual async LDS load on older clang)
        else if constexpr (sizeof(type) == 2) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast(dst), 2, v_os, s_os, 0, aux); } // [BUILTIN-MHC] llvm.amdgcn.raw.buffer.load.lds (DCU?)
        else if constexpr (sizeof(type) == 4) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast(dst), 4, v_os, s_os, 0, aux); } // [BUILTIN-MHC] llvm.amdgcn.raw.buffer.load.lds (DCU?)
#if defined(__gfx950__)
        else if constexpr (sizeof(type) == 12) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast(dst), 12, v_os, s_os, 0, aux); }
        else if constexpr (sizeof(type) == 16) {llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, reinterpret_cast(dst), 16, v_os, s_os, 0, aux); }
#elif defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
        // [DCU/gfx9] hw buffer_load_lds only supports 1/2/4 bytes; emulate larger via multiple b32 loads.
        else if constexpr (sizeof(type) == 8) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 1, 4, v_os + 4, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 12) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 1, 4, v_os + 4, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 2, 4, v_os + 8, s_os, 0, aux);
        } else if constexpr (sizeof(type) == 16) {
            auto* d4 = reinterpret_cast(dst);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 0, 4, v_os + 0, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 1, 4, v_os + 4, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 2, 4, v_os + 8, s_os, 0, aux);
            llvm_amdgcn_raw_buffer_load_lds(cached_rsrc_, d4 + 3, 4, v_os + 12, s_os, 0, aux);
        }
#endif
#endif
    }

    // Raw buffer store of `x`, dispatched on sizeof(vector_type) in {1, 2, 4, 8, 16}.
    // The static_assert requires x to carry exactly vec * vector_size elements.
    // NOTE(review): other sizes match no branch and store nothing -- confirm intended.
    template // os in unit of byte
    OPUS_D void _store(const V& x, int v_os, int s_os = 0, number = {}) {
        static_assert((vec * vector_size) == vector_traits::size(), "vector size need to be same, please check");
        if constexpr (sizeof(vector_type) == 1) { __builtin_amdgcn_raw_buffer_store_b8 (__builtin_bit_cast(i8_t, x), cached_rsrc, v_os, s_os, aux); } // [BUILTIN-MHC] amdgcn_raw_buffer_store_b8 (DCU?)
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
        // DCU targets: same descriptor-copy workaround as in _load, using the LLVM store aliases.
        else if constexpr (sizeof(vector_type) == 2) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); llvm_amdgcn_raw_buffer_store_i16(__builtin_bit_cast(short, x), r_, v_os, s_os, aux); } // [BUILTIN-MHC][DCU-WA] LLVM i16 alias
        else if constexpr (sizeof(vector_type) == 4) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); llvm_amdgcn_raw_buffer_store_i32 (__builtin_bit_cast(int, x), r_, v_os, s_os, aux); } // [BUILTIN-MHC][DCU-WA] LLVM i32 alias
        else if constexpr (sizeof(vector_type) == 8) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); llvm_amdgcn_raw_buffer_store_v2i32(__builtin_bit_cast(i32x2_t, x), r_, v_os, s_os, aux); } // [BUILTIN-MHC][DCU-WA] LLVM v2i32 alias
        else if constexpr (sizeof(vector_type) == 16) { i32x4_t r_; __builtin_memcpy(&r_, &cached_rsrc, sizeof(i32x4_t)); llvm_amdgcn_raw_buffer_store_v4i32(__builtin_bit_cast(i32x4_t, x), r_, v_os, s_os, aux); } // [BUILTIN-MHC][DCU-WA] LLVM v4i32 alias
#else
        else if constexpr (sizeof(vector_type) == 2) { __builtin_amdgcn_raw_buffer_store_b16 (__builtin_bit_cast(i16_t, x), cached_rsrc, v_os, s_os, aux); } // [BUILTIN-MHC] amdgcn_raw_buffer_store_b16 (DCU?)
        else if constexpr (sizeof(vector_type) == 4) { __builtin_amdgcn_raw_buffer_store_b32 (__builtin_bit_cast(i32_t, x), cached_rsrc, v_os, s_os, aux); } // [BUILTIN-MHC] amdgcn_raw_buffer_store_b32 (DCU?)
        else if constexpr (sizeof(vector_type) == 8) { __builtin_amdgcn_raw_buffer_store_b64 (__builtin_bit_cast(i32x2_t, x), cached_rsrc, v_os, s_os, aux); } // [BUILTIN-MHC] amdgcn_raw_buffer_store_b64 (DCU?)
        else if constexpr (sizeof(vector_type) == 16) { __builtin_amdgcn_raw_buffer_store_b128(__builtin_bit_cast(i32x4_t, x), cached_rsrc, v_os, s_os, aux); } // [BUILTIN-MHC] amdgcn_raw_buffer_store_b128 (DCU?)
#endif
    }

    // Element-offset wrapper: scales both offsets by sizeof(T) and forwards to _load.
    template // os in unit of T and cast to vector with vec
    OPUS_D auto load(int v_os, int s_os = 0, number = {}) { return _load(v_os * sizeof(T), s_os * sizeof(T), number{}); }

    // Element-offset wrapper for _async_load; `dst` is converted through an integer
    // (__UINTPTR_TYPE__) round-trip before being handed on.
    template // os in unit of T and cast to vector with vec
    OPUS_D void async_load(void* dst, int v_os, int s_os = 0, number = {}) { _async_load(reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(dst)), v_os * sizeof(T), s_os * sizeof(T), number{}); }

    // Element-offset store. A value taking the first branch is first expanded with
    // make_repeated_vector; otherwise x must already have vec * vector_size elements.
    template || is_dtype_v || is_array_v), bool> = true> // os in unit of T and cast to vector with vec
    OPUS_D void store(const V& x, int v_os, int s_os = 0, number = {}) {
        static_assert(std::is_same_v::dtype, scalar_type>, "scalar type must be same for the data to be stored" );
        if constexpr (is_dtype_v && (vec * vector_size) % vector_traits::size() == 0) {
            _store(make_repeated_vector(x, number::size()>{}), v_os * sizeof(T), s_os * sizeof(T), number{});
        } else {
            static_assert((vec * vector_size) == vector_traits::size(), "vector size need to be same, please check" );
            _store(x, v_os * sizeof(T), s_os * sizeof(T), number{});
        }
    }

    // bulk load API, give me a Shape of this tile, will issue multiple load instruction based on the y-shape space
    // One load per entry of layout_to_offsets(u); results are gathered into a flat vector
    // (OPUS_TILE_CONTAINER == 0) or an array of vectors (== 1).
    // NOTE(review): no branch exists for OPUS_TILE_CONTAINER == 2, so nothing is returned then.
    template, bool> = true>
    OPUS_D auto load(const Layout& u, int s_os = 0/* do we really need this? */, number = {}) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        for (index_t i = 0; i < r_elem.value; i++) {
            auto tmp = load(offsets[i], s_os, number{});
            for (index_t j = 0; j < vec * vector_size; j++) r[i * vec * vector_size + j] = tmp[j];
        }
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        for (index_t i = 0; i < r_elem.value; i++) r[i] = load(offsets[i], s_os, number{});
        return r;
#endif
    }

    // Bulk store: normalises x into an indexable container `a_`, then for each issue i
    // copies vec * vector_size consecutive elements into v_ and stores them at offsets[i].
    template || is_vector_v) && is_layout_v), bool> = true>
    OPUS_D void store(const V& x, const Layout& u, int s_os = 0/* do we really need this? */, number = {}) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
#if OPUS_TILE_CONTAINER == 0
        auto a_ = [&](){ if constexpr (is_array_v) return to_vector(x);
                         else if constexpr (is_dtype_v) return make_repeated_vector(x, number{});
                         else if constexpr (is_vector_v) return x; }();
#elif OPUS_TILE_CONTAINER == 1
        auto a_ = to_array(x);
#endif
        for (index_t i = 0; i < r_elem.value; i++) {
            vector_type v_;
            for (index_t j = 0; j < vec * vector_size; j++) v_[j] = a_[i * vec * vector_size + j];
            store(v_, offsets[i], s_os, number{});
        }
    }

    // Bulk async load: issue i reads gmem_offsets[i] and writes to smem_ptr + smem_offsets[i].
    template && is_layout_v, bool> = true>
    OPUS_D void async_load(void* smem_base, const LayoutG& u_gmem, const LayoutS& u_smem, int s_os = 0, number = {}) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        auto gmem_offsets = layout_to_offsets(u_gmem);
        auto smem_offsets = layout_to_offsets(u_smem);
        auto smem_ptr = reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(smem_base));
        for (index_t i = 0; i < r_elem.value; i++) {
            async_load(reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(smem_ptr + smem_offsets[i])), gmem_offsets[i], s_os, number{});
        }
    }

    // Predicated bulk load: for every index tuple of issue_space_vec, pred(ids...) selects
    // between a real load and a vector_type{0} placeholder. u_r linearises the index tuple.
    template, bool> = true>
    OPUS_D auto load_if(const Predicate& pred, const Layout& u, int s_os = 0, number = {}) {
        using LT = layout_load_traits;
        constexpr auto issue_space = LT::issue_space;   // not referenced below in this overload
        constexpr auto issue_space_vec = LT::issue_space_vec;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
        constexpr auto u_r = make_layout<-1>(issue_space_vec);
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        static_ford(issue_space_vec, [&](auto ... ids){
            constexpr index_t idx = u_r(ids...);
            auto tmp = pred(ids...) ? load(offsets[idx], s_os, number{}) : vector_type{0};
            set_slice(r, tmp, number{}, number<(idx + 1) * vec>{});
        });
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        static_ford(issue_space_vec, [&](auto ... ids){ r[u_r(ids...)] = pred(ids...) ? load(offsets[u_r(ids...)], s_os, number{}) : vector_type{0}; });
        return r;
#endif
    }

    // Predicated bulk store: an issue is stored only when pred(ids...) is true.
    // Note the two linearisations: u_r (built on issue_space) indexes into the source data
    // a_, while make_layout<-1>(issue_space_vec) indexes into `offsets`.
    template || is_vector_v) && is_layout_v), bool> = true>
    OPUS_D void store_if(const Predicate& pred, const V& x, const Layout& u, int s_os = 0, number = {}) {
        using LT = layout_load_traits;
        constexpr auto issue_space = LT::issue_space;
        constexpr auto issue_space_vec = LT::issue_space_vec;
        auto offsets = layout_to_offsets(u);
        constexpr auto u_r = make_layout<-1>(issue_space);
#if OPUS_TILE_CONTAINER == 0
        auto a_ = [&](){ if constexpr (is_array_v) return to_vector(x);
                         else if constexpr (is_dtype_v) return make_repeated_vector(x, number(reduce_tuple_mul(issue_space)).value>{});
                         else if constexpr (is_vector_v) return x; }();
#elif OPUS_TILE_CONTAINER == 1
        auto a_ = to_array(x);
#endif
        static_ford(issue_space_vec, [&](auto ... ids){
            if (pred(ids...)) {
                constexpr index_t idx = u_r(ids...);
                auto v_ = slice(a_, number{}, number{});
                store(v_, offsets[make_layout<-1>(issue_space_vec)(ids...)], s_os, number{});
            }
        });
    }

    // Predicated bulk async load (the per-issue lambda body continues below).
    template && is_layout_v, bool> = true>
    OPUS_D void async_load_if(const Predicate& pred, void* smem_base, const LayoutG& u_gmem, const LayoutS& u_smem, int s_os = 0, number = {}) {
        using LT = layout_load_traits;
        constexpr auto issue_space_vec = LT::issue_space_vec;
        auto gmem_offsets = layout_to_offsets(u_gmem);
        auto smem_offsets = layout_to_offsets(u_smem);
        auto smem_ptr = reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(smem_base));
        constexpr auto u_r = make_layout<-1>(issue_space_vec);
        static_ford(issue_space_vec, [&](auto...
ids) {
            // Tail of gmem::async_load_if's per-issue lambda: a true predicate issues the
            // async global->LDS load; a false one writes a zero vector to the same LDS slot.
            constexpr index_t idx = u_r(ids...);
            if (pred(ids...)) {
                async_load(reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(smem_ptr + smem_offsets[idx])), gmem_offsets[idx], s_os, number{});
            } else {
                using type = vector_type;
                type z = {0};
                *reinterpret_cast(smem_ptr + smem_offsets[idx]) = z;
            }
        });
    }

    __amdgpu_buffer_rsrc_t cached_rsrc;   // descriptor built once in the constructor
#if defined(__gfx1250__)
    const char* raw_ptr;  // flat pointer for global_load_async_to_lds (gfx1250 uses global addressing, not buffer rsrc)
#endif
};

// Factory: builds a gmem from a typed pointer, forwarding size and config to the constructor.
template
OPUS_D decltype(auto) make_gmem(const T_* ptr, unsigned int size = 0xffffffff, unsigned int config = buffer_default_config()) { return gmem{ptr, size, config}; }

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// smem load/store related
// smem: accessor around an OPUS_LDS_ADDR byte pointer (`ptr`).
//   - `_load` / `_tr_load` / `_store` take BYTE offsets.
//   - `load` / `tr_load` / `store` take offsets in units of T and scale by sizeof(T).
//   - Layout and *_if overloads mirror the ones on gmem, without the s_os / aux arguments.
template struct smem {
    using T = remove_cvref_t;
    using scalar_type = typename vector_traits::dtype;
    static constexpr index_t vector_size = vector_traits::size();
    template using vector_type = vector_t;

    // Stores ptr_ after an integer (__UINTPTR_TYPE__) round-trip cast.
    OPUS_D smem(void* ptr_) : ptr(reinterpret_cast(reinterpret_cast<__UINTPTR_TYPE__>(ptr_))) {}

    // Plain dereference of `ptr + v_os` as one vector_type value.
    template
    OPUS_D auto _load(int v_os/* in unit of byte*/) { using type = vector_type; return *reinterpret_cast(ptr + v_os); }

    // Transposing LDS read. Only implemented for device compilation on gfx950, where it
    // emits ds_read_b64_tr_b{16,8,4} chosen by the element bit width; imm_offset becomes
    // the instruction's immediate `offset:` field. The loaded type must be exactly 8 bytes.
    // Any other device target hits a static_assert on instantiation; host compilation
    // falls back to a plain _load at v_os + imm_offset.
    template
    OPUS_D auto _tr_load(int v_os/* in unit of byte*/) {
#if defined(__HIP_DEVICE_COMPILE__) && defined(__gfx950__)
        using type = vector_type;
        static_assert(sizeof(type) == 8, "DS_READ_B64_TR requires 8-byte (64-bit) load");
        constexpr index_t elem_bits = sizeof_bits_v;
        i32x2_t raw;
        // the address is narrowed to 32 bits before being passed as a VGPR operand
        const u32_t addr = static_cast(reinterpret_cast<__UINTPTR_TYPE__>(ptr + v_os));
        if constexpr (elem_bits == 16) { asm volatile("ds_read_b64_tr_b16 %0, %1 offset:%2\n" : "=v"(raw) : "v"(addr), "i"(imm_offset) : "memory"); }
        else if constexpr (elem_bits == 8) { asm volatile("ds_read_b64_tr_b8 %0, %1 offset:%2\n" : "=v"(raw) : "v"(addr), "i"(imm_offset) : "memory"); }
        else if constexpr (elem_bits == 4) { asm volatile("ds_read_b64_tr_b4 %0, %1 offset:%2\n" : "=v"(raw) : "v"(addr), "i"(imm_offset) : "memory"); }
        else { static_assert(sizeof(T_) == 0, "smem::_tr_load: unsupported scalar type"); }
        return __builtin_bit_cast(type, raw);
#elif defined(__HIP_DEVICE_COMPILE__)
        static_assert(sizeof(T_) == 0, "smem::_tr_load requires __gfx950__");
        return _load(v_os + imm_offset);
#else
        return _load(v_os + imm_offset);
#endif
    }

    // Writes x (bit-cast to vector_type) at `ptr + v_os`; element counts must match.
    template
    OPUS_D void _store(const V& x, int v_os/* in unit of byte*/) {
        static_assert((vec * vector_size) == vector_traits::size(), "vector size need to be same, please check");
        using type = vector_type;
        *reinterpret_cast(ptr + v_os) = __builtin_bit_cast(type, x);
    }

    // Element-offset wrappers around the byte-offset primitives above.
    template
    OPUS_D auto load(int v_os) { return _load(v_os * sizeof(T)); }

    template
    OPUS_D auto tr_load(int v_os) { return _tr_load(v_os * sizeof(T)); }

    // Element-offset store; a value taking the first branch is expanded with
    // make_repeated_vector, otherwise x must already have vec * vector_size elements.
    template || is_dtype_v || is_array_v), bool> = true>
    OPUS_D void store(const V& x, int v_os) {
        static_assert(std::is_same_v::dtype, scalar_type>, "scalar type must be same for the data to be stored" );
        if constexpr (is_dtype_v && (vec * vector_size) % vector_traits::size() == 0) {
            _store(make_repeated_vector(x, number::size()>{}), v_os * sizeof(T));
        } else {
            static_assert((vec * vector_size) == vector_traits::size(), "vector size need to be same, please check" );
            _store(x, v_os * sizeof(T));
        }
    }

    // bulk load API, give me a Shape of this tile, will issue multiple load instruction based on the y-shape space
    template, bool> = true>
    OPUS_D auto load(const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        for (index_t i = 0; i < r_elem.value; i++) {
            auto tmp = load(offsets[i]);
            for (index_t j = 0; j < vec * vector_size; j++) r[i * vec * vector_size + j] = tmp[j];
        }
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        for (index_t i = 0; i < r_elem.value; i++) r[i] = load(offsets[i]);
        return r;
#endif
    }

    // Bulk transposing load. When `use_imm` holds and an issue's byte offset fits in
    // [0, 0xffff], the issue goes through _tr_load on the common `base` address (layout
    // evaluated at the all-zero index); otherwise it uses the offset from layout_to_offsets.
    template, bool> = true>
    OPUS_D auto tr_load(const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        using L = remove_cvref_t;
        constexpr bool use_imm = is_static_tuple_v && is_static_tuple_v;
        [[maybe_unused]] const int base = u(transform_tuple([](auto) { return number<0>{}; }, LT::issue_space_vec)) * sizeof(T);
        [[maybe_unused]] auto offsets = layout_to_offsets(u);
        auto do_load = [&](auto i) {
            if constexpr (use_imm) {
                using IMM = layout_imm_offsets;
                constexpr int off = IMM::offsets[i.value] * sizeof(T);
                if constexpr (off >= 0 && off <= 0xffff) { return _tr_load(base); }
            }
            return tr_load(offsets[i.value]);
        };
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        static_for([&](auto i){ set_slice(r, do_load(i), number{}, number<(i.value + 1) * vec>{}); });
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        static_for([&](auto i){ r[i.value] = do_load(i); });
        return r;
#endif
    }

    // Bulk store: same gather-then-store loop as gmem's layout store, using smem::store.
    template || is_dtype_v || is_vector_v) && is_layout_v), bool> = true>
    OPUS_D void store(const V& x, const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
#if OPUS_TILE_CONTAINER == 0
        auto a_ = [&](){ if constexpr (is_array_v) return to_vector(x);
                         else if constexpr (is_dtype_v) return make_repeated_vector(x, number{});
                         else if constexpr (is_vector_v) return x; }();
#elif OPUS_TILE_CONTAINER == 1
        auto a_ = to_array(x);
#endif
        for (index_t i = 0; i < r_elem.value; i++) {
            vector_type v_;
            for (index_t j = 0; j < vec * vector_size; j++) v_[j] = a_[i * vec * vector_size + j];
            store(v_, offsets[i]);
        }
    }

    // Predicated bulk load: issues whose predicate is false yield vector_type{0}.
    template, bool> = true>
    OPUS_D auto load_if(const Predicate& pred, const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto issue_space_vec = LT::issue_space_vec;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
        constexpr auto u_r = make_layout<-1>(issue_space_vec);
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        static_ford(issue_space_vec, [&](auto ... ids){
            constexpr index_t idx = u_r(ids...);
            auto tmp = pred(ids...) ? load(offsets[idx]) : vector_type{0};
            set_slice(r, tmp, number{}, number<(idx + 1) * vec>{});
        });
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        static_ford(issue_space_vec, [&](auto ... ids){ r[u_r(ids...)] = pred(ids...) ? load(offsets[u_r(ids...)]) : vector_type{0}; });
        return r;
#endif
    }

    // Predicated bulk transposing load; identical shape to load_if but calls tr_load.
    template, bool> = true>
    OPUS_D auto tr_load_if(const Predicate& pred, const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto issue_space_vec = LT::issue_space_vec;
        constexpr auto r_elem = LT::r_elem;
        auto offsets = layout_to_offsets(u);
        constexpr auto u_r = make_layout<-1>(issue_space_vec);
#if OPUS_TILE_CONTAINER == 0
        vector_t r;
        static_ford(issue_space_vec, [&](auto ... ids){
            constexpr index_t idx = u_r(ids...);
            auto tmp = pred(ids...) ? tr_load(offsets[idx]) : vector_type{0};
            set_slice(r, tmp, number{}, number<(idx + 1) * vec>{});
        });
        return r;
#elif OPUS_TILE_CONTAINER == 1
        array, r_elem.value> r;
        static_ford(issue_space_vec, [&](auto ... ids){ r[u_r(ids...)] = pred(ids...) ? tr_load(offsets[u_r(ids...)]) : vector_type{0}; });
        return r;
#endif
    }

    // Predicated bulk store. u_r (built on issue_space) indexes the source data a_, while
    // make_layout<-1>(issue_space_vec) indexes `offsets`.
    template || is_dtype_v || is_vector_v) && is_layout_v), bool> = true>
    OPUS_D void store_if(const Predicate& pred, const V& x, const Layout& u) {
        using LT = layout_load_traits;
        constexpr auto issue_space = LT::issue_space;
        constexpr auto issue_space_vec = LT::issue_space_vec;
        auto offsets = layout_to_offsets(u);
        constexpr auto u_r = make_layout<-1>(issue_space);
#if OPUS_TILE_CONTAINER == 0
        auto a_ = [&](){ if constexpr (is_array_v) return to_vector(x);
                         else if constexpr (is_dtype_v) return make_repeated_vector(x, number(reduce_tuple_mul(issue_space)).value>{});
                         else if constexpr (is_vector_v) return x; }();
#elif OPUS_TILE_CONTAINER == 1
        auto a_ = to_array(x);
#endif
        static_ford(issue_space_vec, [&](auto ... ids){
            if (pred(ids...)) {
                constexpr index_t idx = u_r(ids...);
                auto v_ = slice(a_, number{}, number{});
                store(v_, offsets[make_layout<-1>(issue_space_vec)(ids...)]);
            }
        });
    }

    OPUS_LDS_ADDR char* ptr; // in unit of byte
};

// Factory: builds an smem from a typed pointer.
template
OPUS_D decltype(auto) make_smem(T_* ptr) { return smem{ptr}; }

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// mem type traits & free function wrappers (eliminate .template syntax in dependent context)
// is_gmem / is_smem default to false_type and are specialised to true_type for the
// corresponding accessor; is_mem_v is true when either one matches.
template struct is_gmem : false_type {};
template struct is_gmem> : true_type {};
template constexpr bool is_gmem_v = is_gmem>::value;
template struct is_smem : false_type {};
template struct is_smem> : true_type {};
template constexpr bool is_smem_v = is_smem>::value;
template constexpr bool is_mem_v = is_gmem_v || is_smem_v;

// Each wrapper forwards its arguments to the member of the same name on `mem`.
template, bool> = true>
OPUS_D auto load(Mem& mem, Args&&... args) { return mem.template load(std::forward(args)...); }
template, bool> = true>
OPUS_D void store(Mem& mem, Args&&... args) { mem.template store(std::forward(args)...); }
template, bool> = true>
OPUS_D void async_load(Mem& mem, Args&&... args) { mem.template async_load(std::forward(args)...); }
template, bool> = true>
OPUS_D auto load_if(Mem& mem, Args&&... args) { return mem.template load_if(std::forward(args)...); }
template, bool> = true>
OPUS_D auto tr_load(Mem& mem, Args&&... args) { return mem.template tr_load(std::forward(args)...); }
template, bool> = true>
OPUS_D auto tr_load_if(Mem& mem, Args&&... args) { return mem.template tr_load_if(std::forward(args)...); }
template, bool> = true>
OPUS_D void store_if(Mem& mem, Args&&... args) { mem.template store_if(std::forward(args)...); }
template, bool> = true>
OPUS_D void async_load_if(Mem& mem, Args&&...
args) { mem.template async_load_if(std::forward(args)...); }

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// waitcnt
#if defined(__gfx1250__)
// gfx1250: split wait counters, exposed as native instruction wrappers via LLVM IR intrinsics.
// s_wait_expcnt/s_wait_samplecnt/s_wait_bvhcnt do NOT exist on gfx1250.
// The declarations below bind C++ names to LLVM intrinsics through __asm labels; the
// -Wundefined-inline diagnostic is silenced only around them.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-inline"
OPUS_D void llvm_s_wait_loadcnt(short cnt) __asm("llvm.amdgcn.s.wait.loadcnt");
OPUS_D void llvm_s_wait_dscnt(short cnt) __asm("llvm.amdgcn.s.wait.dscnt");
OPUS_D void llvm_s_wait_storecnt(short cnt) __asm("llvm.amdgcn.s.wait.storecnt");
OPUS_D void llvm_s_wait_kmcnt(short cnt) __asm("llvm.amdgcn.s.wait.kmcnt");
OPUS_D void llvm_s_wait_asynccnt(short cnt) __asm("llvm.amdgcn.s.wait.asynccnt");
OPUS_D void llvm_s_wait_tensorcnt(short cnt) __asm("llvm.amdgcn.s.wait.tensorcnt");
#pragma clang diagnostic pop
// Compile-time-count wrappers: each passes its `cnt` template value to the matching intrinsic.
template OPUS_D void s_wait_loadcnt(number = {}) { llvm_s_wait_loadcnt(cnt); }
template OPUS_D void s_wait_dscnt(number = {}) { llvm_s_wait_dscnt(cnt); }
template OPUS_D void s_wait_storecnt(number = {}) { llvm_s_wait_storecnt(cnt); }
template OPUS_D void s_wait_kmcnt(number = {}) { llvm_s_wait_kmcnt(cnt); }
template OPUS_D void s_wait_asynccnt(number = {}) { llvm_s_wait_asynccnt(cnt); }
template OPUS_D void s_wait_tensorcnt(number = {}) { llvm_s_wait_tensorcnt(cnt); }
#else
// gfx9: combined s_waitcnt instruction
// Immediate layout as assembled by the expression below:
//   bits [3:0]   = vmcnt[3:0]
//   bits [6:4]   = expcnt[2:0]
//   bits [11:8]  = lgkmcnt[3:0]
//   bits [15:14] = vmcnt[5:4]   (mask 0b110000 shifted left by 14 - 4)
template
OPUS_D void s_waitcnt(number, number, number = {})
{ __builtin_amdgcn_s_waitcnt((((0b110000 & vmcnt) << (14 - 4)) | (0b1111 & vmcnt)) | ((0b111 & expcnt) << 4) | ((0b1111 & lgkmcnt) << 8)); } // [BUILTIN-MHC] amdgcn_s_waitcnt (DCU?) -- gfx9 combined waitcnt

#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
// [DCU] buffer_load_lds LDS-commit is tracked under lgkmcnt, not vmcnt (unlike upstream gfx9).
// Without lgkmcnt(0) the next LDS read may see stale data. Make s_waitcnt_vmcnt also flush LDS.
template
OPUS_D void s_waitcnt_vmcnt(number) { s_waitcnt(number{}, number<0>{}); }
#else
// Waits on vmcnt only: the second counter is set to 15, the largest value its 4-bit field holds.
template
OPUS_D void s_waitcnt_vmcnt(number) { s_waitcnt(number{}, number<15>{}); }
#endif
// Waits on lgkmcnt only: the first counter is set to 63, the largest value its 6 bits hold.
template
OPUS_D void s_waitcnt_lgkmcnt(number) { s_waitcnt(number<63>{}, number{}); }
#endif

// Helper: resolve vtype for MFMA/WMMA registers. Packed types (fp4_t etc.) use underlying storage since ext_vector_type requires scalar types.
namespace impl {
template struct mfma_vtype { using type = vector_t; };
template struct mfma_vtype>> { using type = vector_t; };
}
template using mfma_vtype_t = typename impl::mfma_vtype::type;

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// mfma (GFX9: gfx942, gfx950)
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
// The DISPATCH_MFMA_* macros expand to "(condition) { return ...; }" and are meant to be
// placed directly after `else if constexpr` inside mfma::operator(). They refer to names
// from the expansion site (a, b, c, cbsz, abid, blgp, wave_m/n/k, elem_a, elem_b).
#define DISPATCH_MFMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { return inst_(a, b, c, cbsz, abid, blgp); }

// K-stepping variant: splits a and b into wk_ / inst_k_ slices and chains inst_ calls,
// feeding each result back in as the accumulator of the next call.
#define DISPATCH_MFMA_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, inst_k_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \
        constexpr index_t steps = wk_ / inst_k_; constexpr index_t e_a = elem_a / steps; constexpr index_t e_b = elem_b / steps; \
        auto tmp = inst_(slice(a, number<0>{}, number{}), slice(b, number<0>{}, number{}), c, cbsz, abid, blgp); \
        static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp, cbsz, abid, blgp); }); \
        return tmp; }

// f32 MFMA: inputs are scalar floats (elem_a = elem_b = 1), extract via [0] from vector_t
#define DISPATCH_MFMA_F32_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { return inst_(a[0], b[0], c, cbsz, abid, blgp); }

// gfx942 _1k bf16 intrinsics require short vectors; bitcast bf16 -> short
// before calling
// (the line above completes the comment that introduces DISPATCH_MFMA_BF16_1K_)
#define DISPATCH_MFMA_BF16_1K_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \
        using _sa = short __attribute__((ext_vector_type(elem_a))); using _sb = short __attribute__((ext_vector_type(elem_b))); \
        return inst_(__builtin_bit_cast(_sa, a), __builtin_bit_cast(_sb, b), c, cbsz, abid, blgp); }

// K-stepping form of the _1k bf16 dispatch: slices a/b into wk_ / inst_k_ pieces, bit-casts
// each piece to a short vector, and chains the results through the accumulator.
#define DISPATCH_MFMA_STEP_K_BF16_1K_(ta_, tb_, tc_, wm_, wn_, wk_, inst_k_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \
        constexpr index_t steps = wk_ / inst_k_; constexpr index_t e_a = elem_a / steps; constexpr index_t e_b = elem_b / steps; \
        using _sa = short __attribute__((ext_vector_type(e_a))); using _sb = short __attribute__((ext_vector_type(e_b))); \
        auto tmp = inst_(__builtin_bit_cast(_sa, slice(a, number<0>{}, number{})), __builtin_bit_cast(_sb, slice(b, number<0>{}, number{})), c, cbsz, abid, blgp); \
        static_for([&](auto i){ tmp = inst_(__builtin_bit_cast(_sa, slice(a, number{}, number{})), __builtin_bit_cast(_sb, slice(b, number{}, number{})), tmp, cbsz, abid, blgp); }); \
        return tmp; }

// fp8/bf8 intrinsics expect packed long (8 x 8-bit = 64-bit); bitcast ext_vector -> long
#define DISPATCH_MFMA_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \
        return inst_(__builtin_bit_cast(long, a), __builtin_bit_cast(long, b), c, cbsz, abid, blgp); }

// scaled MFMA (f8f6f4): input always bitcast to i32x8_t (256 bits); uses format codes and runtime scale
#define DISPATCH_MFMA_SCALE_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
    (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \
        return inst_(__builtin_bit_cast(i32x8_t, a), __builtin_bit_cast(i32x8_t, b), c, fmt_a, fmt_b, 0, scale_a, 0, scale_b); }

// prefer use make_mfma() to create instance, which will return impl::mfma_adaptor_xxx. In this way we can access layout info from the "mma"
//
// Scaled MFMA (gfx950: __builtin_amdgcn_mfma_scale_f32_{32x32x64,16x16x128}_f8f6f4)
// is also dispatched from this struct via the operator()(a, b, c, int scale_a, int scale_b) overload.
// Input registers are always 256 bits (i32x8_t) regardless of element type; bitcast is done internally.
// Format codes (Atype / Btype): 0=fp8(E4M3), 1=bf8(E5M2), 2=fp6(E2M3), 3=bf6(E3M2), 4=fp4(E2M1)
// scale_a, scale_b: E8M0 exponent values (int); actual_scale = 2^(value - 127). Use 127 for no scaling.
//
// Per-lane element counts follow from the tile shape:
//   elem_a = wave_m * wave_k / warp_size, elem_b = wave_n * wave_k / warp_size,
//   elem_c = wave_m * wave_n / warp_size.
template
struct mfma {
    using dtype_a = remove_cvref_t;
    using dtype_b = remove_cvref_t;
    using dtype_c = remove_cvref_t;
    static constexpr index_t wave_m = wave_m_;
    static constexpr index_t wave_n = wave_n_;
    static constexpr index_t wave_k = wave_k_;
    static constexpr index_t warp_size = warp_size_;
    static constexpr index_t elem_a = wave_m * wave_k / warp_size;
    static constexpr index_t elem_b = wave_n * wave_k / warp_size;
    static constexpr index_t elem_c = wave_m * wave_n / warp_size;

    using vtype_a = mfma_vtype_t;
    using vtype_b = mfma_vtype_t;
    using vtype_c = vector_t;

    // Format code for scaled MFMA (f8f6f4); -1 for types that don't support scaling
    static constexpr int fmt_a = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1;
    static constexpr int fmt_b = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1;

    // Regular MFMA dispatch (cbsz/abid/blgp are compile-time parameters)
    // The if-constexpr chain picks the intrinsic that matches (dtype_a, dtype_b, dtype_c,
    // wave_m, wave_n, wave_k) for the target being compiled. When no entry matches (for
    // instance when no target macro below is defined) control reaches __builtin_unreachable().
    template
    OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) -> vtype_c {
        (void)a; (void)b; (void)c;  // used by DISPATCH_MFMA_ macros; suppress -Wunused-parameter on host
        if constexpr (false) {} // in case of macro not defined
#if defined(__gfx942__) || defined(__gfx9_4_generic__) || defined(__gfx950__)
        else if constexpr DISPATCH_MFMA_(fp16_t, fp16_t, fp32_t, 32, 32, 8, __builtin_amdgcn_mfma_f32_32x32x8f16)
        else if constexpr DISPATCH_MFMA_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_mfma_f32_16x16x16f16)
        else if constexpr DISPATCH_MFMA_BF16_1K_(bf16_t, bf16_t, fp32_t, 32, 32, 8, __builtin_amdgcn_mfma_f32_32x32x8bf16_1k)
        else if constexpr DISPATCH_MFMA_BF16_1K_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_mfma_f32_16x16x16bf16_1k)
        else if constexpr DISPATCH_MFMA_8BIT_(fp8_t , fp8_t , fp32_t, 32, 32, 16, __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8)
        else if constexpr DISPATCH_MFMA_8BIT_(fp8_t , fp8_t , fp32_t, 16, 16, 32, __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8)
        else if constexpr DISPATCH_MFMA_8BIT_(bf8_t , bf8_t , fp32_t, 32, 32, 16, __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8)
        else if constexpr DISPATCH_MFMA_8BIT_(bf8_t , bf8_t , fp32_t, 16, 16, 32, __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8)
        else if constexpr DISPATCH_MFMA_F32_(fp32_t, fp32_t, fp32_t, 32, 32, 2, __builtin_amdgcn_mfma_f32_32x32x2f32)
        else if constexpr DISPATCH_MFMA_F32_(fp32_t, fp32_t, fp32_t, 16, 16, 4, __builtin_amdgcn_mfma_f32_16x16x4f32)
#endif
#if defined(__gfx942__) || defined(__gfx9_4_generic__)
        // On these targets the doubled-K shapes are composed from two narrower intrinsic calls.
        else if constexpr DISPATCH_MFMA_STEP_K_(fp16_t, fp16_t, fp32_t, 32, 32, 16, 8, __builtin_amdgcn_mfma_f32_32x32x8f16)
        else if constexpr DISPATCH_MFMA_STEP_K_(fp16_t, fp16_t, fp32_t, 16, 16, 32, 16, __builtin_amdgcn_mfma_f32_16x16x16f16)
        else if constexpr DISPATCH_MFMA_STEP_K_BF16_1K_(bf16_t, bf16_t, fp32_t, 32, 32, 16, 8, __builtin_amdgcn_mfma_f32_32x32x8bf16_1k)
        else if constexpr DISPATCH_MFMA_STEP_K_BF16_1K_(bf16_t, bf16_t, fp32_t, 16, 16, 32, 16, __builtin_amdgcn_mfma_f32_16x16x16bf16_1k)
#endif
#if defined(__gfx950__)
        // On gfx950 the same shapes map to a single intrinsic each.
        else if constexpr DISPATCH_MFMA_(fp16_t, fp16_t, fp32_t, 32, 32, 16, __builtin_amdgcn_mfma_f32_32x32x16_f16)
        else if constexpr DISPATCH_MFMA_(fp16_t, fp16_t, fp32_t, 16, 16, 32, __builtin_amdgcn_mfma_f32_16x16x32_f16)
        else if constexpr DISPATCH_MFMA_(bf16_t, bf16_t, fp32_t, 32, 32, 16, __builtin_amdgcn_mfma_f32_32x32x16_bf16)
        else if constexpr DISPATCH_MFMA_(bf16_t, bf16_t, fp32_t, 16, 16, 32, __builtin_amdgcn_mfma_f32_16x16x32_bf16)
#endif
        __builtin_unreachable(); // suppress warning for return type deduction
    }

    // Convenience overload: starts from a zero-initialised accumulator.
    template
    OPUS_D constexpr auto operator()(const VA& a, const VB& b, number = {}, number = {}, number = {}) {
        vtype_c c{0};
        return operator()(a, b, c, number{}, number{}, number{});
    }

    // Scaled MFMA dispatch (gfx950: f8f6f4 with E8M0 block exponent scaling)
    // scale_a, scale_b are runtime E8M0 exponent values; 127 = no scaling (2^0 = 1.0).
    template
    OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) -> vtype_c {
        (void)a; (void)b; (void)c; (void)scale_a; (void)scale_b;
        if constexpr (false) {}
#if defined(__gfx950__)
        else if constexpr DISPATCH_MFMA_SCALE_(fp8_t, fp8_t, fp32_t, 32, 32, 64, __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4)
        else if constexpr DISPATCH_MFMA_SCALE_(fp8_t, fp8_t, fp32_t, 16, 16, 128, __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4)
        else if constexpr DISPATCH_MFMA_SCALE_(fp4_t, fp4_t, fp32_t, 32, 32, 64, __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4)
        else if constexpr DISPATCH_MFMA_SCALE_(fp4_t, fp4_t, fp32_t, 16, 16, 128, __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4)
#endif
        __builtin_unreachable();
    }

    // Scaled convenience overload: starts from a zero-initialised accumulator.
    template
    OPUS_D constexpr auto operator()(const VA& a, const VB& b, int scale_a, int scale_b) {
        vtype_c c{0};
        return operator()(a, b, c, scale_a, scale_b);
    }
};

// The dispatch helpers are private to struct mfma; drop them so they do not leak to includers.
#undef DISPATCH_MFMA_
#undef DISPATCH_MFMA_F32_
#undef DISPATCH_MFMA_STEP_K_
#undef DISPATCH_MFMA_BF16_1K_
#undef DISPATCH_MFMA_STEP_K_BF16_1K_
#undef DISPATCH_MFMA_8BIT_
#undef DISPATCH_MFMA_SCALE_

// Named shapes. NOTE(review): every right-hand side reads the same in this view; the alias
// names suggest each one binds different element types and tile sizes -- verify against
// the original header.
using mfma_f32_32x32x2_f32 = mfma;
using mfma_f32_16x16x4_f32 = mfma;
using mfma_f32_32x32x8_f16 = mfma;
using mfma_f32_16x16x16_f16 = mfma;
using mfma_f32_32x32x8_bf16 = mfma;
using mfma_f32_16x16x16_bf16 = mfma;
using mfma_f32_32x32x16_f16 = mfma;
using mfma_f32_16x16x32_f16 = mfma;
using mfma_f32_32x32x16_bf16 = mfma;
using mfma_f32_16x16x32_bf16 = mfma;
using mfma_f32_32x32x16_fp8_fp8 = mfma;
using mfma_f32_16x16x32_fp8_fp8 = mfma;
using mfma_f32_32x32x16_bf8_bf8 = mfma;
using mfma_f32_16x16x32_bf8_bf8 = mfma;

// Scaled MFMA type aliases (gfx950 only, unified into struct mfma)
using mfma_f32_32x32x64_fp8_fp8 = mfma;
using mfma_f32_16x16x128_fp8_fp8 = mfma;
using mfma_f32_32x32x64_fp4_fp4 = mfma;
using mfma_f32_16x16x128_fp4_fp4 = mfma;

// Backward-compatible aliases (deprecated: prefer mfma_f32_* above)
using mfma_scale_f32_32x32x64_fp8_fp8 = mfma_f32_32x32x64_fp8_fp8;
using mfma_scale_f32_16x16x128_fp8_fp8 = mfma_f32_16x16x128_fp8_fp8; using mfma_scale_f32_32x32x64_fp4_fp4 = mfma_f32_32x32x64_fp4_fp4; using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // wmma (gfx1250 / RDNA4, wave32) #if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ return inst_(false, a, false, b, static_cast(0), c, false, false); } // bf16f32 special: accumulator is f32 but output is bf16 => (neg_a, A, neg_b, B, fmts, C_f32, clamp, neg_c) // The builtin takes f32 accumulator and returns bf16 output; we store the f32 accum but return bf16. #define DISPATCH_WMMA_BF16F32_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ return inst_(false, a, false, b, static_cast(0), c, false, false); } // fp8/bf8 builtins: (A, B, matrix_fmts, C, clamp, neg_c) -- no neg_a/neg_b // A/B are packed as _ExtVector; bitcast from the fp8/bf8 vector #define DISPATCH_WMMA_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ return inst_(__builtin_bit_cast(vector_t, a), \ __builtin_bit_cast(vector_t, b), \ static_cast(0), c, false, false); } template struct wmma { using dtype_a = remove_cvref_t; using dtype_b = remove_cvref_t; using dtype_c = remove_cvref_t; static constexpr index_t wave_m = 
wave_m_; static constexpr index_t wave_n = wave_n_; static constexpr index_t wave_k = wave_k_; static constexpr index_t warp_size = warp_size_; // 32 for gfx1250 static constexpr index_t elem_a = wave_m * wave_k / warp_size; static constexpr index_t elem_b = wave_n * wave_k / warp_size; static constexpr index_t elem_c = wave_m * wave_n / warp_size; // For packed types (fp4), the hardware register packs multiple elements per byte. // elem counts logical elements; the register holds elem * bits_per_element / 8 bytes. // For non-packed types, sizeof(T) gives bytes per element directly. static constexpr index_t reg_bytes_a = is_packs_v ? (elem_a * sizeof_bits::value / 8) : (elem_a * static_cast(sizeof(dtype_a))); static constexpr index_t reg_bytes_b = is_packs_v ? (elem_b * sizeof_bits::value / 8) : (elem_b * static_cast(sizeof(dtype_b))); // vtype: for packed types, use i32 dword vector matching the hardware register size. // For non-packed types, use mfma_vtype_t (which gives ext_vector of the element type). using vtype_a = std::conditional_t, vector_t(sizeof(i32_t))>, mfma_vtype_t>; using vtype_b = std::conditional_t, vector_t(sizeof(i32_t))>, mfma_vtype_t>; using vtype_c = vector_t; // Format code for scaled WMMA (f8f6f4); -1 for types that don't support scaling static constexpr int fmt_a = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 4 : -1; static constexpr int fmt_b = std::is_same_v ? 0 : std::is_same_v ? 1 : std::is_same_v ? 
4 : -1; // Regular (non-scaled) dispatch template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c) -> vtype_c { (void)a; (void)b; (void)c; if constexpr (false) {} #if defined(__gfx1250__) // f16/bf16 16x16x32 else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp32_t, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x32_f16) else if constexpr DISPATCH_WMMA_(fp16_t, fp16_t, fp16_t, 16, 16, 32, __builtin_amdgcn_wmma_f16_16x16x32_f16) else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, fp32_t, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x32_bf16) else if constexpr DISPATCH_WMMA_(bf16_t, bf16_t, bf16_t, 16, 16, 32, __builtin_amdgcn_wmma_bf16_16x16x32_bf16) // f32 16x16x4 else if constexpr DISPATCH_WMMA_(fp32_t, fp32_t, fp32_t, 16, 16, 4, __builtin_amdgcn_wmma_f32_16x16x4_f32) // fp8/bf8 16x16x64 -> f32 else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x64_fp8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x64_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x64_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x64_bf8_bf8) // fp8/bf8 16x16x64 -> f16 else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp16_t, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x64_fp8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp16_t, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x64_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x64_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x64_bf8_bf8) // fp8/bf8 16x16x128 -> f32 else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 128, __builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 128, 
__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 128, __builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 128, __builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8) // fp8/bf8 16x16x128 -> f16 else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(fp8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif __builtin_unreachable(); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b) { vtype_c c{0}; return operator()(a, b, c); } // Scaled WMMA dispatch (gfx1250: f8f6f4 / f4 with E8M0 block-scale) // scale_a, scale_b are per-lane E8M0 exponent values; 127 = no scaling (2^0 = 1.0). // BX32: int -- 4 packed E8M0 bytes (byte 0 used with scale_sel=0, scale_fmt=0). // BX16: long -- 8 packed E8M0 bytes. // matrix_a_scale_sel controls OPSEL: 0=scale from lanes 0-15, 1=scale from lanes 16-31. // BX32 scaled dispatch template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b, number = {}, number = {}) -> vtype_c { (void)a; (void)b; (void)c; (void)scale_a; (void)scale_b; if constexpr (false) {} #if defined(__gfx1250__) // 16x16x128 f8f6f4 (fp8/fp4 via format code): builtin always takes i32x16 else if constexpr (fmt_a >= 0 && fmt_b >= 0 && std::is_same_v && wave_m == 16 && wave_n == 16 && wave_k == 128) { // For packed types (fp4), vtype may be smaller than i32x16; zero-pad via union. 
auto pad_to_i32x16 = [](const auto& v) { if constexpr (sizeof(v) == sizeof(i32x16_t)) return __builtin_bit_cast(i32x16_t, v); else { union { i32x16_t w; char z[sizeof(i32x16_t)]; } u{}; __builtin_memcpy(&u, &v, sizeof(v)); return u.w; } }; return __builtin_amdgcn_wmma_scale_f32_16x16x128_f8f6f4( fmt_a, pad_to_i32x16(a), fmt_b, pad_to_i32x16(b), static_cast(0), c, a_scale_sel, 0, scale_a, b_scale_sel, 0, scale_b, false, false); } // 32x16x128 f4 (dedicated fp4 instruction): A=i32x16, B=i32x8 else if constexpr (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == 32 && wave_n == 16 && wave_k == 128) { return __builtin_amdgcn_wmma_scale_f32_32x16x128_f4( __builtin_bit_cast(i32x16_t, a), __builtin_bit_cast(i32x8_t, b), static_cast(0), c, a_scale_sel, 0, scale_a, b_scale_sel, 0, scale_b, false, false); } #endif __builtin_unreachable(); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, int scale_a, int scale_b, number = {}, number = {}) { vtype_c c{0}; return operator()(a, b, c, scale_a, scale_b, number{}, number{}); } // BX16 scaled dispatch (scale exponent is long = 64 bits = 8 packed E8M0 bytes) template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, long scale_a, long scale_b, number = {}, number = {}) -> vtype_c { (void)a; (void)b; (void)c; (void)scale_a; (void)scale_b; if constexpr (false) {} #if defined(__gfx1250__) // 16x16x128 f8f6f4 BX16 else if constexpr (fmt_a >= 0 && fmt_b >= 0 && std::is_same_v && wave_m == 16 && wave_n == 16 && wave_k == 128) { auto pad_to_i32x16 = [](const auto& v) { if constexpr (sizeof(v) == sizeof(i32x16_t)) return __builtin_bit_cast(i32x16_t, v); else { union { i32x16_t w; char z[sizeof(i32x16_t)]; } u{}; __builtin_memcpy(&u, &v, sizeof(v)); return u.w; } }; return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4( fmt_a, pad_to_i32x16(a), fmt_b, pad_to_i32x16(b), static_cast(0), c, a_scale_sel, 0, scale_a, b_scale_sel, 0, scale_b, false, false); } // 32x16x128 f4 BX16 else if 
constexpr (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == 32 && wave_n == 16 && wave_k == 128) { return __builtin_amdgcn_wmma_scale16_f32_32x16x128_f4( __builtin_bit_cast(i32x16_t, a), __builtin_bit_cast(i32x8_t, b), static_cast(0), c, a_scale_sel, 0, scale_a, b_scale_sel, 0, scale_b, false, false); } #endif __builtin_unreachable(); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, long scale_a, long scale_b, number = {}, number = {}) { vtype_c c{0}; return operator()(a, b, c, scale_a, scale_b, number{}, number{}); } }; #undef DISPATCH_WMMA_ #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ // f16/bf16 16x16x32 using wmma_f32_16x16x32_f16 = wmma; using wmma_f16_16x16x32_f16 = wmma; using wmma_f32_16x16x32_bf16 = wmma; using wmma_bf16_16x16x32_bf16 = wmma; // f32 16x16x4 using wmma_f32_16x16x4_f32 = wmma; // fp8/bf8 16x16x64 using wmma_f32_16x16x64_fp8_fp8 = wmma; using wmma_f32_16x16x64_fp8_bf8 = wmma; using wmma_f32_16x16x64_bf8_fp8 = wmma; using wmma_f32_16x16x64_bf8_bf8 = wmma; using wmma_f16_16x16x64_fp8_fp8 = wmma; using wmma_f16_16x16x64_fp8_bf8 = wmma; using wmma_f16_16x16x64_bf8_fp8 = wmma; using wmma_f16_16x16x64_bf8_bf8 = wmma; // fp8/bf8 16x16x128 using wmma_f32_16x16x128_fp8_fp8 = wmma; using wmma_f32_16x16x128_fp8_bf8 = wmma; using wmma_f32_16x16x128_bf8_fp8 = wmma; using wmma_f32_16x16x128_bf8_bf8 = wmma; using wmma_f16_16x16x128_fp8_fp8 = wmma; using wmma_f16_16x16x128_fp8_bf8 = wmma; using wmma_f16_16x16x128_bf8_fp8 = wmma; using wmma_f16_16x16x128_bf8_bf8 = wmma; // Scaled WMMA (f8f6f4 unified instruction, supports fp8/bf8/fp4 via format code) using wmma_scale_f32_16x16x128_fp8_fp8 = wmma; using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; // Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; #endif // __gfx1250__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor struct p_dim {}; struct 
y_dim {}; namespace impl{ // utlity function to play with shape template static constexpr auto pickup_filter(seq<>) { return seq<>{}; } template static constexpr auto pickup_filter(seq) { if constexpr (std::is_same_v(FDim{}))>, remove_cvref_t>) return concat_seq(seq{}, pickup_filter(seq{})); else return pickup_filter(seq{}); } template OPUS_D static constexpr auto pickup_shape_apply(seq) { return opus::make_tuple(get(Shape{})...); } template OPUS_D static constexpr auto pickup_shape_impl(const Shape&, const FDim&, Target, seq) { static_assert(size() == size()); return pickup_shape_apply(pickup_filter(seq{})); } template OPUS_D constexpr index_t dim_group_size_sum(seq) { return (static_cast(get(Dim{}).size()) + ... + 0); } template OPUS_D constexpr auto unflatten_shape_group(seq) { constexpr index_t SStart = dim_group_size_sum(make_index_seq{}); return opus::make_tuple(get(Shape{})...); } template OPUS_D constexpr auto unflatten_shape_impl(seq) { return opus::make_tuple(unflatten_shape_group(make_index_seq(Dim{}).size()>{})...); } template OPUS_D constexpr index_t p_count_in(seq) { return ((std::is_same_v(Dim{}))>, p_dim> ? 1 : 0) + ... + 0); } template OPUS_D constexpr auto unfold_p_coord_impl(const Coord& coord, seq) { return opus::make_tuple( [&]() -> decltype(auto) { if constexpr (std::is_same_v(Dim{}))>, p_dim>) return get< p_count_in(make_index_seq{}) >(coord); else return underscore{}; }()... ); } template OPUS_D constexpr index_t dim_offset_sum(seq) { return (static_cast(size(Dim{}))>()) + ... 
+ 0); } template OPUS_D constexpr auto unfold_x_stride_each(const Stride& stride) { constexpr index_t C = dim_offset_sum(make_index_seq{}); constexpr index_t len = size(Dim{}))>(); constexpr auto current_shape = slice(Shape{}, number{}, number{}); constexpr auto current_stride = packed_shape_to_stride(current_shape); return transform_tuple([&](auto i_elem){ return i_elem * get(stride); }, current_stride); } template constexpr index_t unfold_find_group(seq) { index_t acc = 0, r = 0; ((void)(acc += size(Dim{}))>(), (acc <= J ? (void)(r = Gs + 1) : (void)0)), ...); return r; } template OPUS_D constexpr auto unfold_x_stride_at(const Stride& stride) { constexpr index_t G = unfold_find_group(make_index_seq()>{}); constexpr index_t group_end = dim_offset_sum(make_index_seq{}); return packed_stride_at(make_index_seq{}) * get(stride); } template OPUS_D constexpr auto unfold_x_stride_flat(const Stride& stride, seq) { return opus::make_tuple(unfold_x_stride_at(stride)...); } template OPUS_D constexpr auto unfold_x_stride_impl(const Stride& stride, seq) { return unfold_x_stride_flat(stride, make_index_seq()>{}); } } template OPUS_D static constexpr auto pickup_shape(const Shape&, const Dim&, Target) { return pickup_shape_impl(Shape{}, flatten_tuple(Dim{}), Target{}, make_index_seq()>{}); } // Shape : tuple // Dim : tuple, tuple<*, *, *>, tuple<*>> // => : tuple, tuple, tuple> template OPUS_D constexpr auto unflatten_shape(const Shape&, const Dim&) { return impl::unflatten_shape_impl(make_index_seq()>{}); } // coord: tuple, dim: tuple, tuple> -> tuple template OPUS_D constexpr auto unfold_p_coord(const Dim&, const Coord& coord) { constexpr auto flatten_dim = flatten_tuple(Dim{}); using FDim = remove_cvref_t; static_assert(tuple_count(flatten_dim) == size(), "input coord must be same size as p_dim inside Dim"); return impl::unfold_p_coord_impl(coord, make_index_seq()>{}); } template OPUS_D constexpr auto unfold_x_stride(const Dim&, const Shape&, const Stride& stride) { constexpr 
auto flatten_dim = flatten_tuple(Dim{}); static_assert(size() == size(), "input stride must be same size as x_dim"); static_assert(size() == size>(), "input shape must be same size as flattened dim"); return impl::unfold_x_stride_impl(stride, make_index_seq()>{}); } #define OPUS_KP_(x_) static_assert(opus::tuple_count(opus::flatten_tuple(x_ ())) == size()) // Per-axis layout API: generates y_shape_X, p_shape_X, layout_X (3 overloads), layout_X_packed, y_layout_X #define OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(X) \ OPUS_D static constexpr auto y_shape_##X() { return y_shape(shape_##X(), dim_##X()); } \ OPUS_D static constexpr auto p_shape_##X() { return p_shape(shape_##X(), dim_##X()); } \ template OPUS_D constexpr auto layout_##X() { return make_layout(shape_##X());} \ template OPUS_D constexpr auto layout_##X(S&& stride) { return make_layout(shape_##X(), unfold_x_stride(dim_##X(), shape_##X(), stride));} \ template OPUS_D constexpr auto layout_##X(S&& stride, C&& z) { OPUS_KP_(dim_##X); return make_layout(shape_##X(), unfold_x_stride(dim_##X(), shape_##X(), stride), opus::unfold_p_coord(dim_##X(), z));} \ template OPUS_D constexpr auto layout_##X##_packed(C&& z) { OPUS_KP_(dim_##X); return make_layout_packed(shape_##X(), opus::unfold_p_coord(dim_##X(), z));} \ template && ...), bool> = true> OPUS_D constexpr auto layout_##X(Ts&&... 
strides) {return layout_##X(opus::make_tuple(strides...)); } \ template OPUS_D constexpr auto y_layout_##X() { return make_layout(y_shape_##X());} // any struct implement adaptor like feature must implement(or using from base) shape_a/b/c, dim_a/b/c #define OPUS_ADAPTOR_LAYOUT_API_DEFINE \ template OPUS_D static constexpr auto y_shape(const S& /*shape*/, const D& /*dim*/) { return opus::pickup_shape(S{}, D{}, y_dim{}); } \ template OPUS_D static constexpr auto p_shape(const S& /*shape*/, const D& /*dim*/) { return opus::pickup_shape(S{}, D{}, p_dim{}); } \ OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(a) \ OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(b) \ OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(c) // Note: any class to support adaptor need include OPUS_ADAPTOR_LAYOUT_API_DEFINE and implement shape_a()/shape_b()/shape_c() // P indicates dim cross thread, Y indicates dim within thread, this is X layout (X=P+Y) view the tensor as a whole // A:[(grpm_a<p>), (rept_a<y>, grpk_a<p>, pack_a<y>)], MxK
// B:[(grpn_b<p>), (rept_b<y>, grpk_b<p>, pack_b<y>)], NxK
// C:[(rept_c<y>, grpm_c<p>, pack_c<y>), (grpn_c<p>)], MxN
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct mfma_adaptor : public remove_cvref_t { using mfma_type = remove_cvref_t; static constexpr index_t grpm_a = mfma_type::wave_m; static constexpr index_t grpn_b = mfma_type::wave_n; static_assert(mfma_type::warp_size % grpm_a == 0 && mfma_type::warp_size % grpn_b == 0 && grpm_a == grpn_b); static constexpr index_t grpk_a = mfma_type::warp_size / grpm_a; static constexpr index_t grpk_b = grpk_a; static constexpr index_t grpn_c = mfma_type::wave_n; static constexpr index_t grpm_c = mfma_type::warp_size / grpn_c; static constexpr index_t max_pack_a = 16 / sizeof(typename mfma_type::dtype_a); // max 4 dwords static constexpr index_t max_pack_b = 16 / sizeof(typename mfma_type::dtype_b); // max 4 dwords static constexpr index_t max_pack_c = 16 / sizeof(typename mfma_type::dtype_c); // max 4 dwords // pack_* should be vector load from ds_read/global_read static constexpr index_t pack_a = (max_pack_a < mfma_type::elem_a ? max_pack_a : mfma_type::elem_a); static constexpr index_t pack_b = (max_pack_b < mfma_type::elem_b ? max_pack_b : mfma_type::elem_b); static constexpr index_t pack_c = (max_pack_c < mfma_type::elem_c ? max_pack_c : mfma_type::elem_c); static constexpr index_t rept_a = mfma_type::elem_a / pack_a; static constexpr index_t rept_b = mfma_type::elem_b / pack_b; static constexpr index_t rept_c = mfma_type::elem_c / pack_c; // by default, this is X shape, P + Y OPUS_D static constexpr auto shape_a() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto shape_b() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto shape_c() { return tuple, number, number, number>{}; } // here we describe above shape by group them into a 2d shape style, and with p/y dim.
we could put into same structure, but let's make things easier OPUS_D static constexpr auto dim_a() { return tuple< tuple, tuple >{}; } // dim encoding for A, MxK OPUS_D static constexpr auto dim_b() { return tuple< tuple, tuple >{}; } // dim encoding for B, NxK OPUS_D static constexpr auto dim_c() { return tuple< tuple, tuple >{}; } // dim encoding for C, MxN OPUS_ADAPTOR_LAYOUT_API_DEFINE }; // A:[(grpm_a<p>), (rept_a<y>, grpk_a<p>, pack_a<y>)], MxK
// B:[(grpn_b<p>), (rept_b<y>, grpk_b<p>, pack_b<y>)], NxK
// C:[(grpn_c<p>), (rept_c<y>, grpm_c<p>, pack_c<y>)], MxN transposed(!)
template struct mfma_adaptor_swap_ab : mfma_adaptor { using base = mfma_adaptor; using base::shape_a; using base::shape_b; using base::dim_a; using base::dim_b; using base::y_shape; using base::p_shape; using base::y_shape_a; using base::y_shape_b; using base::p_shape_a; using base::p_shape_b; using base::layout_a; using base::layout_b; using base::layout_a_packed; using base::layout_b_packed; using base::y_layout_a; using base::y_layout_b; OPUS_D static constexpr auto shape_c() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto dim_c() { return tuple, tuple >{}; } // dim encoding for C, MxN // Only generate _c layout methods (shape_c/dim_c changed) OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(c) template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) { return base::operator()(b, a, c, number{}, number{}, number{}); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, number = {}, number = {}, number = {}) { typename MFMA::vtype_c c{0}; return operator()(a, b, c, number{}, number{}, number{}); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) { return base::operator()(b, a, c, scale_b, scale_a); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, int scale_a, int scale_b) { typename MFMA::vtype_c c{0}; return operator()(a, b, c, scale_a, scale_b); } }; } // helper class to create adaptor instance for mfma, need be paired with make_mfma().
don't directly use it struct mfma_adaptor { template OPUS_D decltype(auto) operator()(M&&) { return impl::mfma_adaptor>{};} }; struct mfma_adaptor_swap_ab { template OPUS_D decltype(auto) operator()(M&&) { return impl::mfma_adaptor_swap_ab>{};} }; template OPUS_D decltype(auto) make_mfma(number, number, number, A&& = {}, number = {}) { return A{}(mfma{}); } template*/, typename A = mfma_adaptor, index_t warp_size_ = get_warp_size()> OPUS_D decltype(auto) make_mfma(WaveMNK&&, A&& = {}, number = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ // wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250) // A:[(grpm_a<p>), (rept_a<y>, grpk_a<p>, pack_a<y>)], MxK
// B:[(grpn_b<p>), (rept_b<y>, grpk_b<p>, pack_b<y>)], NxK
// C:[(grpm_c<p>, rept_c<y>, pack_c<y>), (grpn_c<p>)], MxN
#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct wmma_adaptor : public remove_cvref_t { using wmma_type = remove_cvref_t; static constexpr index_t grpm_a = wmma_type::wave_m; static constexpr index_t grpn_b = wmma_type::wave_n; static_assert(wmma_type::warp_size % grpm_a == 0 && wmma_type::warp_size % grpn_b == 0 && grpm_a == grpn_b); static constexpr index_t grpk_a = wmma_type::warp_size / grpm_a; static constexpr index_t grpk_b = grpk_a; static constexpr index_t grpn_c = wmma_type::wave_n; static constexpr index_t grpm_c = wmma_type::warp_size / grpn_c; static constexpr index_t max_pack_a = 16 / sizeof(typename wmma_type::dtype_a); static constexpr index_t max_pack_b = 16 / sizeof(typename wmma_type::dtype_b); static constexpr index_t max_pack_c = 16 / sizeof(typename wmma_type::dtype_c); static constexpr index_t pack_a = (max_pack_a < wmma_type::elem_a ? max_pack_a : wmma_type::elem_a); static constexpr index_t pack_b = (max_pack_b < wmma_type::elem_b ? max_pack_b : wmma_type::elem_b); static constexpr index_t pack_c = (max_pack_c < wmma_type::elem_c ?
max_pack_c : wmma_type::elem_c); static constexpr index_t rept_a = wmma_type::elem_a / pack_a; static constexpr index_t rept_b = wmma_type::elem_b / pack_b; static constexpr index_t rept_c = wmma_type::elem_c / pack_c; OPUS_D static constexpr auto shape_a() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto shape_b() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto shape_c() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto dim_a() { return tuple< tuple, tuple >{}; } OPUS_D static constexpr auto dim_b() { return tuple< tuple, tuple >{}; } OPUS_D static constexpr auto dim_c() { return tuple< tuple, tuple >{}; } OPUS_ADAPTOR_LAYOUT_API_DEFINE }; template struct wmma_adaptor_swap_ab : wmma_adaptor { using base = wmma_adaptor; using base::shape_a; using base::shape_b; using base::dim_a; using base::dim_b; using base::y_shape; using base::p_shape; using base::y_shape_a; using base::y_shape_b; using base::p_shape_a; using base::p_shape_b; using base::layout_a; using base::layout_b; using base::layout_a_packed; using base::layout_b_packed; using base::y_layout_a; using base::y_layout_b; OPUS_D static constexpr auto shape_c() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto dim_c() { return tuple, tuple >{}; } // Only generate _c layout methods (shape_c/dim_c changed) OPUS_ADAPTOR_LAYOUT_API_DEFINE_FOR(c) template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c) { return base::operator()(b, a, c); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b) { typename WMMA::vtype_c c{0}; return operator()(b, a, c); } // Scaled overloads (BX32 / BX16): swap a,b then forward to base template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) { return base::operator()(b, a, c, scale_a, scale_b); } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, long scale_a, long 
scale_b) { return base::operator()(b, a, c, scale_a, scale_b); } }; } // namespace impl (wmma_adaptor) struct wmma_adaptor { template OPUS_D decltype(auto) operator()(M&&) { return impl::wmma_adaptor>{};} }; struct wmma_adaptor_swap_ab { template OPUS_D decltype(auto) operator()(M&&) { return impl::wmma_adaptor_swap_ab>{};} }; template OPUS_D decltype(auto) make_wmma(number, number, number, A&& = {}, number = {}) { return A{}(wmma{}); } template OPUS_D decltype(auto) make_wmma(WaveMNK&&, A&& = {}, number = {}) { return A{}(wmma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __gfx1250__ ///////////////////////////////////////////////////////////////////////////////////////////////////////// namespace impl { // tiled mma, warp level mfma/wmma/... EXPAND_: each wave need repeat along m/n/k dim how many times. TILE_: number of waves in m/n/k dim // A:[(expd_m, tile_m

// A:[(expd_m<y>, tile_m<p>), (expd_k<y>, tile_k<p>)]
// B:[(expd_n<y>, tile_n<p>), (expd_k<y>, tile_k<p>)]
// C:[(expd_m<y>, tile_m<p>), (expd_n<y>, tile_n<p>)]
template struct tiled_mma_adaptor : public MMA_ { using MMA = remove_cvref_t; static constexpr index_t expd_m = EXPAND_M; static constexpr index_t expd_n = EXPAND_N; static constexpr index_t expd_k = EXPAND_K; static constexpr index_t tile_m = TILE_M; static constexpr index_t tile_n = TILE_N; static constexpr index_t tile_k = TILE_K; #if OPUS_TILE_CONTAINER == 0 using vtype_a = vector_t; using vtype_b = vector_t; using vtype_c = vector_t; #elif OPUS_TILE_CONTAINER == 1 using vtype_a = array; using vtype_b = array; using vtype_c = array; #endif OPUS_D static constexpr auto tile_shape_a() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto tile_shape_b() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto tile_shape_c() { return tuple, number, number, number>{}; } OPUS_D static constexpr auto tile_dim_a() { return tuple< tuple, tuple >{}; } // dim encoding for A, MxK OPUS_D static constexpr auto tile_dim_b() { return tuple< tuple, tuple >{}; } // dim encoding for B, NxK OPUS_D static constexpr auto tile_dim_c() { return tuple< tuple, tuple >{}; } // dim encoding for C, MxN OPUS_D static constexpr auto shape_a() { return flatten_tuple(embed_nested_tuple(unflatten_shape(tile_shape_a(), tile_dim_a()), unflatten_shape(MMA::shape_a(), MMA::dim_a()))); } OPUS_D static constexpr auto shape_b() { return flatten_tuple(embed_nested_tuple(unflatten_shape(tile_shape_b(), tile_dim_b()), unflatten_shape(MMA::shape_b(), MMA::dim_b()))); } OPUS_D static constexpr auto shape_c() { return flatten_tuple(embed_nested_tuple(unflatten_shape(tile_shape_c(), tile_dim_c()), unflatten_shape(MMA::shape_c(), MMA::dim_c()))); } OPUS_D static constexpr auto dim_a() { return embed_nested_tuple(tile_dim_a(), MMA::dim_a()); } // dim encoding for A, MxK OPUS_D static constexpr auto dim_b() { return embed_nested_tuple(tile_dim_b(), MMA::dim_b()); } // dim encoding for B, NxK OPUS_D static constexpr auto dim_c() { return embed_nested_tuple(tile_dim_c(),
MMA::dim_c()); } // dim encoding for A, MxK // Cached tile sizes (avoids re-evaluating y_shape + reduce_tuple_mul in every operator/step_k) static constexpr index_t mma_a_len = get<0>(reduce_tuple_mul(MMA::y_shape_a())).value; static constexpr index_t mma_b_len = get<0>(reduce_tuple_mul(MMA::y_shape_b())).value; static constexpr index_t mma_c_len = get<0>(reduce_tuple_mul(MMA::y_shape_c())).value; static constexpr index_t tile_a_len = EXPAND_M * EXPAND_K * mma_a_len; static constexpr index_t tile_b_len = EXPAND_N * EXPAND_K * mma_b_len; static constexpr index_t tile_c_len = EXPAND_M * EXPAND_N * mma_c_len; // input a/b/c is array of ext type e.g. "fp16x2_t a[2];", pass "a" to this function template > && is_array_v< remove_cvref_t > && is_array_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) { VC c_ {c}; for (index_t I = 0; I < EXPAND_K * EXPAND_M * EXPAND_N; I++) { index_t i_k = I / (EXPAND_M * EXPAND_N), i_m = (I / EXPAND_N) % EXPAND_M, i_n = I % EXPAND_N; auto s_a = a[i_m * EXPAND_K + i_k]; auto s_b = b[i_n * EXPAND_K + i_k]; auto s_c = c_[i_m * EXPAND_N + i_n]; s_c = MMA{}(s_a, s_b, s_c); c_[i_m * EXPAND_N + i_n] = s_c; } return c_; } template > && is_vector_v< remove_cvref_t > && is_vector_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) { static_assert(size() == tile_a_len); static_assert(size() == tile_b_len); static_assert(size() == tile_c_len); constexpr index_t a_len = mma_a_len, b_len = mma_b_len, c_len = mma_c_len; VC c_ {c}; for (index_t I = 0; I < EXPAND_K * EXPAND_M * EXPAND_N; I++) { index_t i_k = I / (EXPAND_M * EXPAND_N), i_m = (I / EXPAND_N) % EXPAND_M, i_n = I % EXPAND_N; index_t i_a = (i_m * EXPAND_K + i_k) * a_len, i_b = (i_n * EXPAND_K + i_k) * b_len, i_c = (i_m * EXPAND_N + i_n) * c_len; typename MMA::vtype_a s_a; for (index_t j = 0; j < a_len; 
j++) s_a[j] = a[i_a + j]; typename MMA::vtype_b s_b; for (index_t j = 0; j < b_len; j++) s_b[j] = b[i_b + j]; typename MMA::vtype_c s_c; for (index_t j = 0; j < c_len; j++) s_c[j] = c_[i_c + j]; s_c = MMA{}(s_a, s_b, s_c); for (index_t j = 0; j < c_len; j++) c_[i_c + j] = s_c[j]; } return c_; } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, number = {}, number = {}, number = {}) { vtype_c c{0}; return operator()(a, b, c, number{}, number{}, number{}); } // Scaled MFMA (f8f6f4): forward scale_a, scale_b to underlying MMA template > && is_array_v< remove_cvref_t > && is_array_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) { VC c_ {c}; for (index_t I = 0; I < EXPAND_K * EXPAND_M * EXPAND_N; I++) { index_t i_k = I / (EXPAND_M * EXPAND_N), i_m = (I / EXPAND_N) % EXPAND_M, i_n = I % EXPAND_N; auto s_a = a[i_m * EXPAND_K + i_k]; auto s_b = b[i_n * EXPAND_K + i_k]; auto s_c = c_[i_m * EXPAND_N + i_n]; s_c = MMA{}(s_a, s_b, s_c, scale_a, scale_b); c_[i_m * EXPAND_N + i_n] = s_c; } return c_; } template > && is_vector_v< remove_cvref_t > && is_vector_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto operator()(const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) { static_assert(size() == tile_a_len); static_assert(size() == tile_b_len); static_assert(size() == tile_c_len); constexpr index_t a_len = mma_a_len, b_len = mma_b_len, c_len = mma_c_len; VC c_ {c}; static_ford([&](auto i_k, auto i_m, auto i_n){ constexpr index_t i_tile_a = i_m * EXPAND_K + i_k; constexpr index_t i_tile_b = i_n * EXPAND_K + i_k; constexpr index_t i_tile_c = i_m * EXPAND_N + i_n; auto s_a = slice(a, number{}, number{}); auto s_b = slice(b, number{}, number{}); auto s_c = slice(c_, number{}, number{}); s_c = MMA{}(s_a, s_b, s_c, scale_a, scale_b); set_slice(c_, s_c, number{}, number{}); }); return c_; } template OPUS_D constexpr auto operator()(const VA& a, const VB& b, int 
scale_a, int scale_b) { vtype_c c{0}; return operator()(a, b, c, scale_a, scale_b); } template > && is_array_v< remove_cvref_t > && is_array_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto step_k(number, const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) { static_assert(STEP_K < EXPAND_K); VC c_ {c}; static_for([&](auto I){ constexpr index_t i_m = I.value / EXPAND_N, i_n = I.value % EXPAND_N; auto s_a = a[i_m * EXPAND_K + STEP_K]; auto s_b = b[i_n * EXPAND_K + STEP_K]; auto s_c = c_[i_m * EXPAND_N + i_n]; s_c = MMA{}(s_a, s_b, s_c); c_[i_m * EXPAND_N + i_n] = s_c; }); return c_; } template > && is_vector_v< remove_cvref_t > && is_vector_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto step_k(number, const VA& a, const VB& b, const VC& c, number = {}, number = {}, number = {}) { static_assert(STEP_K < EXPAND_K); static_assert(size() == tile_a_len); static_assert(size() == tile_b_len); static_assert(size() == tile_c_len); constexpr index_t a_len = mma_a_len, b_len = mma_b_len, c_len = mma_c_len; VC c_ {c}; for (index_t I = 0; I < EXPAND_M * EXPAND_N; I++) { index_t i_m = I / EXPAND_N, i_n = I % EXPAND_N; index_t i_a = (i_m * EXPAND_K + STEP_K) * a_len, i_b = (i_n * EXPAND_K + STEP_K) * b_len, i_c = (i_m * EXPAND_N + i_n) * c_len; typename MMA::vtype_a s_a; for (index_t j = 0; j < a_len; j++) s_a[j] = a[i_a + j]; typename MMA::vtype_b s_b; for (index_t j = 0; j < b_len; j++) s_b[j] = b[i_b + j]; typename MMA::vtype_c s_c; for (index_t j = 0; j < c_len; j++) s_c[j] = c_[i_c + j]; s_c = MMA{}(s_a, s_b, s_c); for (index_t j = 0; j < c_len; j++) c_[i_c + j] = s_c[j]; } return c_; } template OPUS_D constexpr auto step_k(number step, const VA& a, const VB& b, number = {}, number = {}, number = {}) { vtype_c c{0}; return step_k(step, a, b, c, number{}, number{}, number{}); } template > && is_array_v< remove_cvref_t > && is_array_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto step_k(number, const VA& a, const VB& b, 
const VC& c, int scale_a, int scale_b) { static_assert(STEP_K < EXPAND_K); VC c_ {c}; for (index_t I = 0; I < EXPAND_M * EXPAND_N; I++) { index_t i_m = I / EXPAND_N, i_n = I % EXPAND_N; auto s_a = a[i_m * EXPAND_K + STEP_K]; auto s_b = b[i_n * EXPAND_K + STEP_K]; auto s_c = c_[i_m * EXPAND_N + i_n]; s_c = MMA{}(s_a, s_b, s_c, scale_a, scale_b); c_[i_m * EXPAND_N + i_n] = s_c; } return c_; } template > && is_vector_v< remove_cvref_t > && is_vector_v< remove_cvref_t >), bool > = true> OPUS_D constexpr auto step_k(number, const VA& a, const VB& b, const VC& c, int scale_a, int scale_b) { static_assert(STEP_K < EXPAND_K); static_assert(size() == tile_a_len); static_assert(size() == tile_b_len); static_assert(size() == tile_c_len); constexpr index_t a_len = mma_a_len, b_len = mma_b_len, c_len = mma_c_len; VC c_ {c}; for (index_t I = 0; I < EXPAND_M * EXPAND_N; I++) { index_t i_m = I / EXPAND_N, i_n = I % EXPAND_N; index_t i_a = (i_m * EXPAND_K + STEP_K) * a_len, i_b = (i_n * EXPAND_K + STEP_K) * b_len, i_c = (i_m * EXPAND_N + i_n) * c_len; typename MMA::vtype_a s_a; for (index_t j = 0; j < a_len; j++) s_a[j] = a[i_a + j]; typename MMA::vtype_b s_b; for (index_t j = 0; j < b_len; j++) s_b[j] = b[i_b + j]; typename MMA::vtype_c s_c; for (index_t j = 0; j < c_len; j++) s_c[j] = c_[i_c + j]; s_c = MMA{}(s_a, s_b, s_c, scale_a, scale_b); for (index_t j = 0; j < c_len; j++) c_[i_c + j] = s_c[j]; } return c_; } template OPUS_D constexpr auto step_k(number step, const VA& a, const VB& b, int scale_a, int scale_b) { vtype_c c{0}; return step_k(step, a, b, c, scale_a, scale_b); } OPUS_ADAPTOR_LAYOUT_API_DEFINE }; } struct tiled_mma_adaptor { template OPUS_D decltype(auto) operator()(MMA&&, number...) 
{ return impl::tiled_mma_adaptor, Ts...>{};} }; template OPUS_D decltype(auto) make_tiled_mma(MMA&& mma, number, number, number, number, number, number, A&& = {}) { return A{}(std::forward(mma), number{}, number{}, number{}, number{}, number{}, number{}); } template OPUS_D decltype(auto) make_tiled_mma(MMA&& mma, ES, TS, A&& = {}) { return A{}(std::forward(mma), number(ES{})>{}, number(ES{})>{}, number(ES{})>{}, number(TS{})>{}, number(TS{})>{}, number(TS{})>{}); } template OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) { #if defined(__gfx1250__) return TA{}(make_wmma(WS{}, WA{}, number{}), #else return TA{}(make_mfma(WS{}, WA{}, number{}), #endif number(ES{})>{}, number(ES{})>{}, number(ES{})>{}, number(TS{})>{}, number(TS{})>{}, number(TS{})>{}); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// template && is_tuple_v && is_tuple_v && is_tuple_v, bool> = true> OPUS_D constexpr auto partition_layout(L&& layout, D&& dims, S&& shapes, C&& p_coord) { OPUS_KP_(dims); return make_layout(std::forward(shapes), unfold_x_stride(std::forward(dims), std::forward(shapes), layout.stride()), unfold_p_coord(std::forward(dims), p_coord)); } // partition, use cached_vec to dispatch which layout implementation. 
cached_vec < 0 : "layout", cached_vec == 0 : "layout_linear", cached_vec > 0 : "layout_cached" /* partition_layout_a/b/c: thin free-function wrappers that forward to mma.layout_a/b/c with no argument, with x_stride, or with x_stride and p_coord. NOTE(review): presumably cached_vec is the template argument passed to layout_x, but it is not visible in this text - confirm. */ template OPUS_D constexpr auto partition_layout_a(M&& mma) { return mma.template layout_a(); } template OPUS_D constexpr auto partition_layout_b(M&& mma) { return mma.template layout_b(); } template OPUS_D constexpr auto partition_layout_c(M&& mma) { return mma.template layout_c(); } template, bool> = true> OPUS_D constexpr auto partition_layout_a(M&& mma, S&& x_stride) { return mma.template layout_a(std::forward(x_stride)); } template, bool> = true> OPUS_D constexpr auto partition_layout_b(M&& mma, S&& x_stride) { return mma.template layout_b(std::forward(x_stride)); } template, bool> = true> OPUS_D constexpr auto partition_layout_c(M&& mma, S&& x_stride) { return mma.template layout_c(std::forward(x_stride)); } template && is_tuple_v, bool> = true> OPUS_D constexpr auto partition_layout_a(M&& mma, S&& x_stride, C&& p_coord) { return mma.template layout_a(std::forward(x_stride), std::forward(p_coord)); } template && is_tuple_v, bool> = true> OPUS_D constexpr auto partition_layout_b(M&& mma, S&& x_stride, C&& p_coord) { return mma.template layout_b(std::forward(x_stride), std::forward(p_coord)); } template && is_tuple_v, bool> = true> OPUS_D constexpr auto partition_layout_c(M&& mma, S&& x_stride, C&& p_coord) { return mma.template layout_c(std::forward(x_stride), std::forward(p_coord)); } /* The _packed variants take only p_coord and forward it to mma.layout_a_packed, layout_b_packed or layout_c_packed. */ template, bool> = true> OPUS_D constexpr auto partition_layout_a_packed(M&& mma, C&& p_coord) { return mma.template layout_a_packed(std::forward(p_coord)); } template, bool> = true> OPUS_D constexpr auto partition_layout_b_packed(M&& mma, C&& p_coord) { return mma.template layout_b_packed(std::forward(p_coord)); } template, bool> = true> OPUS_D constexpr auto partition_layout_c_packed(M&& mma, C&& p_coord) { return mma.template layout_c_packed(std::forward(p_coord)); } /* OPUS_KP_ is a helper macro used by partition_layout; it is undefined here so it does not leak out of the header. */ #undef OPUS_KP_ } // namespace opus // call this macro within your kernel body to have fast access to opus types 
// OPUS_USING_COMMON_TYPES: brings the opus _I literal operator into scope and declares the p_dim and y_dim aliases.
#define OPUS_USING_COMMON_TYPES \ using opus::operator""_I; \ using p_dim = opus::p_dim; \ using y_dim = opus::y_dim; // call this macro in global scope (outside of your kernel function, or under structure) /* OPUS_USING_COMMON_TYPES_ALL: expands OPUS_USING_COMMON_TYPES and adds the num, tup and seq alias templates for opus::number, opus::tuple and opus::seq; alias templates cannot be declared at block scope, hence the separate macro. */ #define OPUS_USING_COMMON_TYPES_ALL \ OPUS_USING_COMMON_TYPES \ template using num = opus::number; \ template using tup = opus::tuple; \ template using seq = opus::seq; // clang-format on