Unverified commit d480a5a6, authored by Max Podkorytov, committed by GitHub

Merge branch 'develop' into ck-flex

parents bca939ce 9c5b2f39
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#ifndef CK_CODE_GEN_RTC
#pragma once
#include <cstdlib>
@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck
+#endif
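The new CK_CODE_GEN_RTC guard keeps host-only headers such as <cstdlib> out of the translation unit when CK headers are consumed as runtime-compiled kernel source. A minimal sketch of how that guard might be driven from the host with hipRTC follows; the kernel string, program name, and option list are illustrative assumptions, not part of this change.

// Sketch, assuming hipRTC is available; a real use would also pass -I options
// so the runtime compiler can find the CK headers included by the kernel source.
#include <hip/hiprtc.h>
#include <cstdio>

int main()
{
    const char* src = "extern \"C\" __global__ void noop() {}";

    hiprtcProgram prog;
    hiprtcCreateProgram(&prog, src, "noop.cpp", 0, nullptr, nullptr);

    // Define the guard so host-only code paths in CK headers are compiled out.
    const char* opts[] = {"-DCK_CODE_GEN_RTC"};
    const hiprtcResult rc = hiprtcCompileProgram(prog, 1, opts);
    std::printf("hiprtc compile result: %d\n", static_cast<int>(rc));

    hiprtcDestroyProgram(&prog);
    return 0;
}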
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
    {
-        return std::forward<X>(x);
+        return ck::forward<X>(x);
    }
    else
    {
-        return std::forward<Y>(y);
+        return ck::forward<Y>(y);
    }
}
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
    template <typename F, typename X>
    __host__ __device__ constexpr auto operator()(F&& f, X&& x) const
    {
-        return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
+        return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
    }
};
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
    template <typename F, typename X, typename Y>
    __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
-        return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
-                                  std::forward<Y>(y).At(Number<Js>{})...);
+        return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
+                                 ck::forward<Y>(y).At(Number<Js>{})...);
    }
};
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
    using X_ = remove_reference_t<X>;
    return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
-        std::forward<F>(f), std::forward<X>(x));
+        ck::forward<F>(f), ck::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
    using Y_ = remove_reference_t<Y>;
    return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
                                typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
-        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
+        ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
}
} // namespace ck
...
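unpack expands a CK container into the arguments of a callable via .At(Number<Is>{})... . A rough usage sketch, assuming the usual ck/utility include layout and ck::make_tuple from tuple.hpp; the helper below is illustrative, not part of the commit.

#include "ck/utility/tuple.hpp"
#include "ck/utility/functional4.hpp"

// Sketch: sum the elements of a ck::Tuple by unpacking them into a
// variadic lambda (the fold expression requires C++17).
constexpr int sum_of_tuple()
{
    return ck::unpack([](auto... xs) { return (xs + ... + 0); }, ck::make_tuple(1, 2, 3));
}
static_assert(sum_of_tuple() == 6, "unpack should forward all three elements");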
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
    return integral_constant<decltype(X % Y), X % Y>{};
}
+template <bool B>
+using bool_constant = integral_constant<bool, B>;
+using true_type = bool_constant<true>;
+using false_type = bool_constant<false>;
} // namespace ck
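bool_constant, true_type and false_type mirror their std counterparts so traits can be written without pulling in <type_traits>. A small illustrative trait built on them (the trait name is made up for the example):

#include "ck/utility/integral_constant.hpp"

// Illustrative trait: true only for float.
template <typename T>
struct is_plain_float : ck::false_type
{
};
template <>
struct is_plain_float<float> : ck::true_type
{
};
static_assert(is_plain_float<float>::value, "float should qualify");
static_assert(!is_plain_float<double>::value, "double should not qualify");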
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
+#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
-    using value_t = std::false_type;
+    using value_t = integral_constant<bool, false>;
    using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
-struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
+struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
-    using value_t = std::true_type;
+    using value_t = integral_constant<bool, true>;
    using type = Op<Args...>;
};
} // namespace detail
@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
-using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
+using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
template <typename T>
-using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
+using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
template <typename T>
-using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
+using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
} // namespace ck
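With the detector now built on ck::void_t and ck's integral_constant, is_detected keeps working when std traits are unavailable under RTC. A sketch of the detection idiom in use; the two test structs and the include path (inferred from the is_detected.hpp include elsewhere in this diff) are assumptions.

#include "ck/utility/is_detected.hpp"

struct WithFlag
{
    static constexpr bool is_pack2_invocable = true;
};
struct WithoutFlag
{
};

// is_pack2_invocable_t<T> is well-formed only when T declares the member,
// so is_detected resolves to a true/false integral_constant accordingly.
static_assert(ck::is_detected<ck::is_pack2_invocable_t, WithFlag>::value, "");
static_assert(!ck::is_detected<ck::is_pack2_invocable_t, WithoutFlag>::value, "");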
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+#ifndef CK_CODE_GEN_RTC
#include <ostream>
+#endif
#pragma once
@@ -25,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck
+#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
    switch(s)
@@ -35,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
    }
    return os;
}
+#endif
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,6 +9,10 @@
#include "type.hpp"
#include "tuple.hpp"
+#ifdef CK_CODE_GEN_RTC
+#define INT32_MAX 2147483647
+#endif
namespace ck {
// magic number division
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
+#ifndef CK_CODE_GEN_RTC
static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); };
@@ -459,7 +459,7 @@ inline __host__ double expm1<double>(double x)
{
    return std::expm1(x);
}
+#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
+#include <ck/utility/ignore.hpp>
#include "ck/ck.hpp"
+#ifdef CK_CODE_GEN_RTC
+using uint8_t = unsigned char;
+using uint16_t = unsigned short;
+using uint32_t = unsigned int;
+#endif
namespace ck {
// Pseudo random number generator
// version for fp32
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
@@ -25,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
@@ -40,15 +45,14 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// return 0 if data is not fp16 or fp32
-template <
-    typename T,
-    uint32_t seed_t,
-    std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
+template <typename T,
+          uint32_t seed_t,
+          ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
-    std::ignore = id;
-    std::ignore = val;
-    std::ignore = seed;
+    ck::ignore = id;
+    ck::ignore = val;
+    ck::ignore = seed;
    return 0;
}
...
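prand_generator hashes the element index together with the bit pattern of the value, which the stochastic-rounding f8 conversions later in this diff rely on. A rough host-side usage sketch; the seed constant matches the one those conversions use, the rest (function name, include path) is illustrative.

#include "ck/utility/random_gen.hpp"

// Sketch: derive one pseudo-random word per element, as f8_convert_sr does,
// using the element index as the id argument.
void fill_noise(const float* src, unsigned int* noise, int n)
{
    constexpr unsigned int seed = 1254739;
    for(int i = 0; i < n; ++i)
    {
        noise[i] = ck::prand_generator<float, seed>(i, src[i]);
    }
}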
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
+#ifndef CK_CODE_GEN_RTC
#include <ostream>
+#endif
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
+#ifndef CK_CODE_GEN_RTC
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
    os << S::At(S::Size() - ck::Number<1>{}).value << "}";
    return os;
}
+#endif
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
-template <
-    typename... Ys,
-    typename X,
-    enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
    return y;
}
-template <
-    typename... Ys,
-    typename X,
-    enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
    return y;
}
-template <
-    typename... Xs,
-    typename Y,
-    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
    return r;
}
-template <
-    typename... Xs,
-    typename Y,
-    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
    return r;
}
-template <
-    typename... Xs,
-    typename Y,
-    enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
// MultiIndex = scalar * MultiIndex
template <typename... Xs,
          typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{
    constexpr index_t NSize = sizeof...(Xs);
@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
// MultiIndex = MultiIndex * scalar
template <typename... Xs,
          typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{
    return a * x;
...
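These operators give MultiIndex (a Tuple alias) element-wise arithmetic, and the constraint swap to ck::is_integral / ck::is_floating_point keeps the SFINAE working without <type_traits>. A usage sketch, assuming element-wise semantics in the elided operator bodies and an include path inferred from the header guard; both are assumptions.

#include "ck/utility/statically_indexed_array_multi_index.hpp" // path assumed from the include guard

// Sketch: element-wise arithmetic on index tuples.
inline auto example_indices()
{
    auto a = ck::make_tuple(1, 2, 3);
    auto b = ck::make_tuple(4, 5, 6);
    auto sum    = a + b; // element-wise sum, sizes must match
    auto scaled = 2 * a; // scalar * MultiIndex overload
    return ck::make_tuple(sum, scaled);
}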
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -32,7 +32,7 @@ struct TupleElementKeyData
    template <typename T,
              typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
                                 bool>::type = false>
-    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
+    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
    {
    }
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
-    return std::forward(x.mData);
+    return ck::forward(x.mData);
}
template <typename Indices, typename... Xs>
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
                                 !is_same<remove_cvref_t<Y>, TupleImpl>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Y&& y)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
    {
    }
    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
    {
        static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                      "wrong! inconsistent size");
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
    template <typename Y,
              typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                 bool>::type = false>
-    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
+    __host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
    {
    }
    template <typename... Ys,
              typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
                  false>
-    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
+    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
    {
    }
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
-    return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
+    return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
+#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp"
+#endif
namespace ck {
@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
                                                              const Tuple<Y&...>& ty)
{
    return unpack2(
-        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
-        [&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
    }
}
+#ifndef CK_CODE_GEN_RTC
template <typename T>
-using is_tuple = decltype(std::declval<T&>().IsTuple());
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#endif
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
+#ifndef CK_CODE_GEN_RTC
    return (is_detected<is_tuple, Ts>::value || ...);
+#endif
}
template <index_t depth = 0, typename T>
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
-#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"
namespace ck {
-template <typename X, typename Y>
-struct is_same : public integral_constant<bool, false>
-{
-};
-template <typename X>
-struct is_same<X, X> : public integral_constant<bool, true>
-{
-};
-template <typename X, typename Y>
-inline constexpr bool is_same_v = is_same<X, Y>::value;
-template <typename T>
-using remove_reference_t = typename std::remove_reference<T>::type;
-template <typename T>
-using remove_cv_t = typename std::remove_cv<T>::type;
-template <typename T>
-using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
-template <typename T>
-using remove_pointer_t = typename std::remove_pointer<T>::type;
-template <typename T>
-inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
-template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
-__host__ __device__ constexpr Y bit_cast(const X& x)
-{
-    static_assert(__has_builtin(__builtin_bit_cast), "");
-    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
-    return __builtin_bit_cast(Y, x);
-}
-} // namespace ck
+#ifdef CK_CODE_GEN_RTC
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT1(name) \
+    template <class T> \
+    struct name : bool_constant<__##name(T)> \
+    { \
+    }
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT2(name) \
+    template <class T, class U> \
+    struct name : bool_constant<__##name(T, U)> \
+    { \
+    }
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAITN(name) \
+    template <class... Ts> \
+    struct name : bool_constant<__##name(Ts...)> \
+    { \
+    }
+CK_BUILTIN_TYPE_TRAIT1(is_class);
+CK_BUILTIN_TYPE_TRAIT1(is_pointer);
+CK_BUILTIN_TYPE_TRAIT1(is_reference);
+CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
+CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
+CK_BUILTIN_TYPE_TRAIT2(is_base_of);
+template <class T>
+struct remove_cv
+{
+    using type = T;
+};
+template <class T>
+struct remove_cv<const T> : remove_cv<T>
+{
+};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template <class T>
struct remove_reference
{
typedef T type;
};
template <class T>
struct remove_reference<T&>
{
typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
struct remove_pointer
{
typedef T type;
};
template <class T>
struct remove_pointer<T*>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* volatile>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const volatile>
{
typedef T type;
};
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <class T>
struct is_const : public integral_constant<bool, false>
{
};
template <class T>
struct is_const<const T> : public integral_constant<bool, true>
{
};
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;
template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;
template <class T>
struct remove_const
{
typedef T type;
};
template <class T>
struct remove_const<const T>
{
typedef T type;
};
template <class T>
using remove_const_t = typename remove_const<T>::type;
template <class T>
inline constexpr bool is_class_v = is_class<T>::value;
template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
// template <typename T>
// T&& declval() noexcept;
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
template <class...>
using void_t = void;
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_class_v;
using std::is_const_v;
using std::is_pointer;
using std::is_reference;
using std::is_reference_v;
using std::is_trivially_copyable;
using std::is_trivially_copyable_v;
using std::is_unsigned;
using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::void_t;
#endif
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_floating_point : public integral_constant<bool, false>
{
};
template <>
struct is_floating_point<float> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<double> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<long double> : public integral_constant<bool, true>
{
};
template <typename X>
struct is_integral : public integral_constant<bool, false>
{
};
template <>
struct is_integral<int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<signed char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<wchar_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char16_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char32_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<bool> : public integral_constant<bool, true>
{
};
template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename remove_pointer<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
static_assert(__has_builtin(__builtin_bit_cast), "");
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
return __builtin_bit_cast(Y, x);
}
} // namespace ck
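Under CK_CODE_GEN_RTC the header now supplies its own forward, declval, remove_reference and builtin-backed traits, and re-exports the std versions otherwise, so the rest of CK can use ck:: qualifications uniformly. A small illustrative wrapper exercising a few of them; the wrapper itself is not part of the commit.

#include "ck/utility/type.hpp"

// Illustrative perfect-forwarding helper built only on ck:: utilities,
// mirroring the std::forward -> ck::forward replacements in this commit.
template <typename F, typename Arg>
constexpr decltype(auto) invoke_once(F&& f, Arg&& arg)
{
    static_assert(!ck::is_pointer_v<ck::remove_reference_t<F>>,
                  "plain callables only in this sketch");
    return ck::forward<F>(f)(ck::forward<Arg>(arg));
}

constexpr int add_one(int x) { return x + 1; }
static_assert(invoke_once(add_one, 41) == 42, "forwarding should preserve the call");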
@@ -52,10 +52,10 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
          typename X,
-          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
+          ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
    return static_cast<Y>(x);
}
@@ -63,13 +63,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
          typename X,
-          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
+          ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
-    using NonConstY = std::remove_const_t<Y>;
-    using NonConstX = std::remove_const_t<X>;
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
+    using NonConstY = ck::remove_const_t<Y>;
+    using NonConstX = ck::remove_const_t<X>;
    return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
@@ -149,7 +149,7 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
-    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
    return static_cast<Y>(x);
}
@@ -211,7 +211,11 @@ template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
    constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
#if defined(__gfx94__)
    union
    {
@@ -251,7 +255,12 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
    constexpr bool clip = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
    return utils::cast_to_f8<half_t,
                             f8_fnuz_t,
                             negative_zero_nan,
@@ -265,7 +274,11 @@ template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
    constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
#if defined(__gfx94__)
    union
    {
@@ -307,7 +320,12 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
    constexpr bool clip = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed = 1254739;
+#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
    return utils::cast_to_f8<half_t,
                             bf8_fnuz_t,
                             negative_zero_nan,
@@ -629,20 +647,22 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
}
+#ifndef CK_CODE_GEN_RTC
-template <typename Y, typename X, std::size_t NumElems>
+template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                               const std::array<X, NumElems>& x)
{
-    for(std::size_t i = 0; i < NumElems; i++)
+    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}
+#endif
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
-    for(std::size_t i = 0; i < NumElems; i++)
+    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
...
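array_convert over std::array is now host-only behind CK_CODE_GEN_RTC, while the ck::Array overload stays usable in kernels. A host-side usage sketch; the value types and include path are assumptions following the ck/utility layout.

#include <array>
#include "ck/utility/type_convert.hpp"

// Sketch: widen a std::array of halves to float with the host-only overload;
// each element goes through ck::type_convert<float>(half_t).
inline void widen(const std::array<ck::half_t, 4>& in, std::array<float, 4>& out)
{
    ck::array_convert(out, in);
}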
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
-#define CK_TILE_MAX_RANK 5
+#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
+// this epilogue aiming to store a matrix with different layout from the shared memory to the global
+// memory.
template <typename AccDataType_,
          typename ODataType_,
-          bool kPadM_,
-          bool kPadN_,
-          bool kTilePermute_,
-          index_t kRank_,
-          index_t kPerm0,
-          index_t kPerm1,
-          index_t TileSize0,
-          index_t TileSize1,
-          index_t kPerm2 = 0,
-          index_t kPerm3 = 0,
-          index_t kPerm4 = 0,
-          index_t TileSize2 = 0,
-          index_t TileSize3 = 0,
-          index_t TileSize4 = 0>
+          typename CLayout_,
+          index_t kBlockSize_,
+          index_t kM_,
+          index_t kN_,
+          index_t kMWave_,
+          index_t kNWave_,
+          index_t kMPerXdl_,
+          index_t kNPerXdl_,
+          index_t kKPerXdl_,
+          bool isCTransposed_>
struct CShuffleEpilogueProblem
{
    using AccDataType = remove_cvref_t<AccDataType_>;
    using ODataType = remove_cvref_t<ODataType_>;
-    static constexpr bool kPadM = kPadM_;
-    static constexpr bool kPadN = kPadN_;
-    static constexpr bool kTilePermute = kTilePermute_;
-    static constexpr index_t kRank = kRank_;
-    static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
-    static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
-        TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
+    using CLayout = remove_cvref_t<CLayout_>;
+    static constexpr index_t kBlockSize = kBlockSize_;
+    static constexpr index_t kMPerBlock = kM_;
+    static constexpr index_t kNPerBlock = kN_;
+    static constexpr index_t kMWave = kMWave_;
+    static constexpr index_t kNWave = kNWave_;
+    static constexpr index_t kMPerXdl = kMPerXdl_;
+    static constexpr index_t kNPerXdl = kNPerXdl_;
+    static constexpr index_t kKPerXdl = kKPerXdl_;
+    static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void> template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue struct CShuffleEpilogue
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM; using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr bool kPadN = Problem::kPadN; static constexpr index_t kBlockSize = Problem::kBlockSize;
const index_t* kPerm = Problem::kPerm; static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr bool kTilePermute = Problem::kTilePermute; static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t kRank = Problem::kRank; static constexpr index_t kMWave = Problem::kMWave;
const index_t* tile_sizes = Problem::tile_sizes; static constexpr index_t kNWave = Problem::kNWave;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
// No additional shared memory needed static constexpr index_t kNPerXdl = Problem::kNPerXdl;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{ {
// TODO: At now CShuffle doesn't allow to vector store after permute. constexpr index_t MaxVectorStoreSize = 16;
// It should be fixed and this function should return true. return MaxVectorStoreSize / sizeof(ODataType);
return false;
} }
template <typename OAccTile> template <typename Problem>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{ {
using DataType = typename OAccTile::DataType; // N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{ {
// Convert linear index to multi-dimensional indices return make_naive_tensor_descriptor(
array<index_t, kRank> indices; make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
index_t remaining = linear_idx; make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
} }
// M is contiguous dimension
// Copy the permuted data back to the original thread buffer else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
for(index_t i = 0; i < total_elements; ++i) {
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
}
else
{ {
thread_buf.set_as(i, permuted_thread_buf.get(i)); static_assert(false, "Unsupported CLayout!");
} }
} }
template <typename ODramWindowTmp, CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
}
template <typename ODramWindow,
typename OAccTile, typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set> memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
{ {
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin const index_t iMWarp = get_warp_id() / kNWave;
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0}; const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
for(index_t i = 0; i < kRank; ++i)
{ constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i]; auto o_lds_block = make_tensor_view<address_space_enum::lds>(
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]); static_cast<ODataType*>(p_smem), lds_block_desc);
} auto in_lds_window =
make_tile_window(o_lds_block,
typename ODramWindowTmp::BottomTensorIndex step = {}; make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
for(index_t i = 0; i < kRank; ++i) {number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
{ auto out_lds_window =
step[i] = permuted_window_origin[i] - current_window_origin[i]; make_tile_window(o_lds_block,
} make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
if constexpr(out_memory_data_op == memory_operation_enum::set) if constexpr(out_memory_data_op == memory_operation_enum::set)
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); store_tile(out_dram_window, c_out_tensor);
} }
else else
{ {
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); update_tile(out_dram_window, c_out_tensor);
} }
buffer_store_fence(); if constexpr(iAccess != num_access - 1)
}
else
{
if constexpr(out_memory_data_op == memory_operation_enum::set)
{ {
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
} }
else });
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
} }
}; };
} // namespace ck_tile } // namespace ck_tile
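The reworked CShuffleEpilogue stages each wave's kMPerXdl x kNPerXdl tile through LDS, placing it at an offset derived from the wave id (iMWarp = get_warp_id() / kNWave, iNWarp = the remainder) before reading it back through a block-wide tile window. A standalone sketch of just that wave-coordinate arithmetic, with illustrative wave counts:

#include <cassert>

// Illustrative wave counts; in the kernel they come from CShuffleEpilogueProblem.
constexpr int kMWave = 2;
constexpr int kNWave = 2;

constexpr int m_warp_of(int warp_id) { return warp_id / kNWave; }
constexpr int n_warp_of(int warp_id) { return warp_id - m_warp_of(warp_id) * kNWave; }

int main()
{
    // 4 waves per workgroup -> LDS tile coordinates (0,0), (0,1), (1,0), (1,1).
    for(int w = 0; w < kMWave * kNWave; ++w)
    {
        assert(m_warp_of(w) < kMWave);
        assert(n_warp_of(w) < kNWave);
    }
    return 0;
}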
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
+#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
    static constexpr bool UseRawStore = UseRawStore_;
};
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
@@ -35,14 +57,13 @@ struct Default2DEpilogue
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
-    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    // how do we fix this ?
    template <typename ODramWindowTmp,
              typename OAccTile,
              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
-    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
+    CK_TILE_DEVICE auto
+    operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
    {
        // TODO: this is ugly
@@ -71,4 +92,76 @@ struct Default2DEpilogue
    }
}
};
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
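// --- Hedged illustration (not part of the diff above) ---
// The vector-size selection in GetVectorSizeC() boils down to two cases per C layout:
// either a thread owns a contiguous run of elements along the fast dimension (use that
// run length), or consecutive threads interleave along it (fall back to the per-lane
// count). The sketch below mirrors that decision with plain constexpr values; the
// function name and the numbers are illustrative only and do not come from
// WarpGemmMfmaDispatcher.
template <bool kRowMajorC, bool kCTransposed>
constexpr int pick_c_vector_size(int contiguous_run_per_lane, int interleaved_per_lane)
{
    if constexpr(kRowMajorC)
        return kCTransposed ? contiguous_run_per_lane : interleaved_per_lane;
    else
        return kCTransposed ? interleaved_per_lane : contiguous_run_per_lane;
}
static_assert(pick_c_vector_size<true, true>(4, 1) == 4,
              "row-major C + transposed accumulators: contiguous run along N");
static_assert(pick_c_vector_size<true, false>(4, 1) == 1,
              "row-major C + non-transposed accumulators: lanes interleave along N");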
} // namespace ck_tile } // namespace ck_tile
...@@ -70,7 +70,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -70,7 +70,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__ static constexpr auto __host__ static constexpr auto
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
{ {
return TilePartitioner::GridSize(M, N, KBatch * batch_count); return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
} }
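// Hedged note (illustrative numbers, assuming a 1D tile partitioner whose
// GridSize(M, N) returns a scalar tile count): with 128x128 output tiles,
// M = 512, N = 384, batch_count = 8 and KBatch = 2 the launch grid becomes
//   dim3(ceil(512/128) * ceil(384/128), 8, 2) = dim3(12, 8, 2),
// so blockIdx.x selects the C tile, blockIdx.y the batch and blockIdx.z the
// split-K slice, matching how operator() reads them back below.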
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
...@@ -101,14 +101,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -101,14 +101,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
{ {
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z / kargs.KBatch); const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_k = __builtin_amdgcn_readfirstlane(blockIdx.z - i_batch * kargs.KBatch); const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_k); const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
// options // options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A); const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
...@@ -128,7 +128,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep ...@@ -128,7 +128,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
if(kargs.KBatch == 1) if(kargs.k_batch == 1)
{ {
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
} }
......
...@@ -75,12 +75,12 @@ struct GemmKernel ...@@ -75,12 +75,12 @@ struct GemmKernel
static constexpr auto I1 = number<1>(); static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>(); static constexpr auto I2 = number<2>();
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{ {
return TilePartitioner::GridSize(M, N, KBatch); return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
} }
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
struct GemmKernelArgs struct GemmKernelArgs
{ {
...@@ -93,7 +93,7 @@ struct GemmKernel ...@@ -93,7 +93,7 @@ struct GemmKernel
index_t stride_A; index_t stride_A;
index_t stride_B; index_t stride_B;
index_t stride_C; index_t stride_C;
index_t KBatch; index_t k_batch;
}; };
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
...@@ -121,7 +121,7 @@ struct GemmKernel ...@@ -121,7 +121,7 @@ struct GemmKernel
const std::size_t k_id = blockIdx.z) const std::size_t k_id = blockIdx.z)
{ {
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.KBatch * K1; const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -142,13 +142,13 @@ struct GemmKernel ...@@ -142,13 +142,13 @@ struct GemmKernel
b_k_split_offset = k_id * KRead; b_k_split_offset = k_id * KRead;
} }
if(k_id < static_cast<uint32_t>(kargs.KBatch - 1)) if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{ {
splitted_k = KRead; splitted_k = KRead;
} }
else else
{ {
splitted_k = kargs.K - KRead * (kargs.KBatch - 1); splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
} }
} }
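// Hedged worked example for the split-K bookkeeping above (illustrative numbers):
// with K = 1000, k_batch = 4 and a warp-tile K1 = 32:
//   K_t   = 4 * 32 = 128
//   KRead = ceil(1000 / 128) * 32 = 8 * 32 = 256
// blocks with k_id < 3 get splitted_k = 256, while the last block gets
// splitted_k = 1000 - 256 * 3 = 232, so the four slices cover K exactly.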
...@@ -159,14 +159,10 @@ struct GemmKernel ...@@ -159,14 +159,10 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{ {
constexpr bool is_output_c_reg_transposed = if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); is_any_of<CDataType, fp16_t, bf16_t>::value)
if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed) ||
!(std::is_same_v<CDataType, fp16_t> || std::is_same_v<CDataType, bf16_t>)))
{ {
if(kargs.KBatch != 1) if(kargs.k_batch != 1)
{ {
std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
return false; return false;
...@@ -182,7 +178,7 @@ struct GemmKernel ...@@ -182,7 +178,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeA != 0) if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
{ {
std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
...@@ -197,7 +193,7 @@ struct GemmKernel ...@@ -197,7 +193,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeA != 0) if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
return false; return false;
...@@ -213,7 +209,7 @@ struct GemmKernel ...@@ -213,7 +209,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeB != 0) if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
...@@ -228,7 +224,7 @@ struct GemmKernel ...@@ -228,7 +224,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.K % GemmPipeline::VectorSizeB != 0) if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
{ {
std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
return false; return false;
...@@ -244,7 +240,7 @@ struct GemmKernel ...@@ -244,7 +240,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.N % GemmPipeline::VectorSizeC != 0) if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -259,7 +255,7 @@ struct GemmKernel ...@@ -259,7 +255,7 @@ struct GemmKernel
<< std::endl; << std::endl;
return false; return false;
} }
if(kargs.M % GemmPipeline::VectorSizeC != 0) if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{ {
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false; return false;
...@@ -275,14 +271,6 @@ struct GemmKernel ...@@ -275,14 +271,6 @@ struct GemmKernel
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset) const SplitKBatchOffset& splitk_batch_offset)
{ {
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const auto& a_tensor_view = [&]() { const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -290,7 +278,7 @@ struct GemmKernel ...@@ -290,7 +278,7 @@ struct GemmKernel
a_ptr, a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::GetVectorSizeA()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -299,7 +287,7 @@ struct GemmKernel ...@@ -299,7 +287,7 @@ struct GemmKernel
a_ptr, a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M), make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1), make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{}, number<GemmPipeline::GetVectorSizeA()>{},
number<1>{}); number<1>{});
} }
}(); }();
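// Hedged note with illustrative numbers: for row-major A with M = 256, K = 512 and
// stride_A = 512, the view above has lengths (256, splitted_k) and strides (512, 1);
// for column-major A it has lengths (splitted_k, 256) with the same stride pair, so
// the last (fastest-varying) dimension always carries stride 1 and the vector load
// size GetVectorSizeA() applies to it.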
...@@ -311,7 +299,7 @@ struct GemmKernel ...@@ -311,7 +299,7 @@ struct GemmKernel
b_ptr, b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N), make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::GetVectorSizeB()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -320,7 +308,7 @@ struct GemmKernel ...@@ -320,7 +308,7 @@ struct GemmKernel
b_ptr, b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1), make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{}, number<GemmPipeline::GetVectorSizeB()>{},
number<1>{}); number<1>{});
} }
}(); }();
...@@ -333,7 +321,7 @@ struct GemmKernel ...@@ -333,7 +321,7 @@ struct GemmKernel
c_ptr, c_ptr,
make_tuple(kargs.M, kargs.N), make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1), make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{}, number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{}); number<1>{});
} }
else else
...@@ -501,22 +489,14 @@ struct GemmKernel ...@@ -501,22 +489,14 @@ struct GemmKernel
// Run Epilogue Pipeline // Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2); auto& c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed = EpiloguePipeline{}
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || c_block_window, c_block_tile, smem_ptr);
(GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
{
EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile);
}
} }
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
{ {
const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
...@@ -531,14 +511,20 @@ struct GemmKernel ...@@ -531,14 +511,20 @@ struct GemmKernel
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
if(kargs.KBatch == 1) if(kargs.k_batch == 1)
{ {
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
} }
else else
{ {
RunGemm<memory_operation_enum::atomic_add>( // Do not compile in case where we have unsupported
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); // VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
}
} }
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
*/
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
/** @brief Struct representing 2D block index mapping into 3D output tile space. */ /**
* @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
*
*/
template <typename BlockGemmShapeType> template <typename BlockGemmShapeType>
struct GemmTile2DPartitioner struct GemmTile2DPartitioner
{ {
...@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner ...@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
/** @brief Returns 3D grid size. */ CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) noexcept( CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 [[maybe_unused]] index_t N) noexcept;
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
{ {
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock; const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
const index_t GridDimZ = batch_size; return dim3(GridDimX, GridDimY, 1);
return dim3(GridDimX, GridDimY, GridDimZ);
} }
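// Hedged example (illustrative numbers): for M = 300, N = 200 and 128x128 tiles,
// GridDimX = ceil(300/128) = 3 and GridDimY = ceil(200/128) = 2, i.e. dim3(3, 2, 1).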
/** /**
* @brief Returns the number of loops. * @brief Calculate number of loop iterations over GEMM's K dimension.
* @param [in] K is dimension *
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/ */
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{ {
return integer_divide_ceil(K, KPerBlock); return integer_divide_ceil(K, KPerBlock);
} }
...@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner ...@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
* @param [in] blockIdy is blockIdx.y * @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes. * @return Returns the output tile indexes.
*/ */
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx,
index_t blockIdy) noexcept /**
* @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's X index.
* @param blockIdy WGP's Y index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
-> const tuple<index_t, index_t> -> const tuple<index_t, index_t>
{ {
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
...@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner ...@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
}; };
/** /**
* @brief Struct representing 1D block index mapping into 2D output tile space. * @brief Class providing 1D WGP index mapping into 2D output C-tile space.
*
* @tparam BlockGemmShape_ A class providing basic GEMM parameters (see TileGemmShape).
*/ */
template <typename BlockGemmShapeType> template <typename BlockGemmShape_>
struct GemmTile1DPartitioner struct GemmTile1DPartitioner
{ {
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
/** @brief delete default ctr with no any object */ CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;
constexpr GemmTile1DPartitioner() noexcept = delete;
/** @brief constructs an object that does contain a N value. */
constexpr GemmTile1DPartitioner(index_t N) noexcept { N_ = N; }
/** @brief Returns 1D grid size. */ /**
CK_TILE_HOST static constexpr auto * @brief Construct a new GemmTile1DPartitioner object.
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3 *
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
*/
CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
{ {
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; N_ = N;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return dim3(GridDimX * GridDimY, 1, 1);
} }
/** /**
* @brief Returns the number of blocks in N. * @brief Calculates GEMM kernel grid size.
* @param [in] N is dimension *
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/ */
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N) noexcept -> index_t CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{ {
return integer_divide_ceil(N, NPerBlock); const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
return GridDimX * GridDimY;
} }
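// Hedged example (illustrative numbers): for M = 300, N = 200 and 128x128 tiles the
// 1D variant flattens the same 3 x 2 tile grid into a single dimension of 6 workgroups.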
/** /**
* @brief Returns the number of loops. * @brief Calculate number of loop iterations over GEMM's K dimension.
* @param [in] K is dimension *
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/ */
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) noexcept -> index_t CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{ {
return integer_divide_ceil(K, KPerBlock); return integer_divide_ceil(K, KPerBlock);
} }
/** /**
* @brief The function returns 2D output tile space. * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
* @param [in] blockIdx is blockIdx.x - block_start. *
* */ * @param blockIdx WGP's index.
CK_TILE_DEVICE static constexpr auto GetOutputTileIndex(index_t blockIdx) noexcept * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
-> const tuple<index_t, index_t> -> const tuple<index_t, index_t>
{ {
const index_t NBlock = GetNBlock(N_); const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlock); const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - (iM)*NBlock); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
return make_tuple(iM, iN); return make_tuple(iM, iN);
} }
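// Hedged example (illustrative numbers): with N_ = 384 and NPerBlock = 128,
// NBlocks = 3, so blockIdx = 7 maps to iM = 7 / 3 = 2 and iN = 7 - 2 * 3 = 1,
// i.e. the 1D index walks the C-tile grid row by row.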
...@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn ...@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed, * enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type. * otherwise std::false_type.
*/ */
template <typename PartitionerFn, template <typename TilePartitioner,
typename = typename std::enable_if_t<HasFnOneArgImpl<PartitionerFn>{}>> typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
struct OffsettedTile1DPartitioner struct OffsettedTile1DPartitioner
{ {
/** /**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes. * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
* @param [in] block_start is `blockIdx.x - block_start`. * @param [in] block_start Workgroup offset.
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index. * @param [in] M Gemm's M dimension.
* @param [in] N Gemm's N dimension.
* @return Returns a `tuple` [Im, In] with shifted index.
*/ */
[[nodiscard]] CK_TILE_DEVICE static constexpr auto GetOffsetedTileIndex(index_t block_start, [[nodiscard]] CK_TILE_DEVICE static auto
index_t N) noexcept GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
-> const tuple<index_t, index_t> -> const tuple<index_t, index_t>
{ {
const auto [iM, iN] = PartitionerFn(N).GetOutputTileIndex(blockIdx.x - block_start); const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
return make_tuple(iM, iN); return make_tuple(iM, iN);
} }
}; };
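// Hedged usage sketch (group_offset, Shape and kargs are hypothetical names): a grouped
// kernel that packs several GEMMs into one 1D grid could recover its local tile index as
//   const auto [iM, iN] =
//       OffsettedTile1DPartitioner<GemmTile1DPartitioner<Shape>>::GetOffsetedTileIndex(
//           group_offset, kargs.M, kargs.N);
// where group_offset is the first blockIdx.x owned by the current GEMM.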
/**
* @brief Class mapping 1D block index into 2D output tile space.
*
* @note It spatially groups workgroups in order to better utilize caches.
* It uses a grouped rows-of-column-vectors WGP pattern and is optimized
* for gfx94x-like multi-die chips.
*
* @tparam GroupNum - The number of big groups.
* @tparam M01 - The number of groups along the M dimension within spatially local WGPs.
*
*/
template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
struct GemmSpatiallyLocalTilePartitioner
{
using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept
: M(M_), N(N_)
{
}
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return index_t A total number of workgroups.
*/
CK_TILE_HOST static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
return GridDimX * GridDimY;
}
/**
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
{
return integer_divide_ceil(K, KPerBlock);
}
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param [in] block_1d_id WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
-> const tuple<index_t, index_t>
{
const auto M0 = integer_divide_ceil(M, MPerBlock);
const auto N0 = integer_divide_ceil(N, NPerBlock);
if(M0 == 1)
{
return make_tuple(0, block_1d_id);
}
else if(N0 == 1)
{
return make_tuple(block_1d_id, 0);
}
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
else
{
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
const auto group_id_y = block_1d_id / GroupNum;
const auto group_id_x = block_1d_id - group_id_y * GroupNum;
const auto remap_block_1d_id =
group_id_x <= big_group_num
? group_id_x * group_size + group_id_y
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
const index_t idx_M0 = remap_block_1d_id / N0;
const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
const index_t M0_tmp = M0 / M01;
const index_t M0_mod_M01 = M0 - M0_tmp * M01;
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
const index_t idx_M00 = idx_M0 / M01;
const index_t idx_M01 = idx_M0 - idx_M00 * M01;
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
const index_t N_out = idx_N0_M01_local / M01_adapt;
const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
}
}
private:
index_t M;
index_t N;
};
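// --- Hedged, self-contained host-side sketch (not part of the file above) ---
// It mirrors the remapping arithmetic of GetOutputTileIndex() with plain integers so the
// documented example (M0 = 5, N0 = 4, M01 = 2, block_1d_id = 5 -> tile {1, 2}) can be
// checked off-device. GroupNum = 1 is assumed so that the big-group remapping is the
// identity, matching the example; the M0 == 1 / N0 == 1 shortcuts are omitted.
struct TileIdx
{
    int m;
    int n;
};
constexpr TileIdx
spatially_local_tile(int block_1d_id, int M0, int N0, int GroupNum, int M01)
{
    const int group_size    = (M0 * N0 + GroupNum - 1) / GroupNum;
    const int big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
    const int group_id_y    = block_1d_id / GroupNum;
    const int group_id_x    = block_1d_id - group_id_y * GroupNum;
    // Same ternary as in GetOutputTileIndex() above.
    const int remap = group_id_x <= big_group_num
                          ? group_id_x * group_size + group_id_y
                          : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
    const int idx_M0     = remap / N0;
    const int idx_N0     = remap - idx_M0 * N0;
    const int M0_mod_M01 = M0 - (M0 / M01) * M01;
    const int M01_adapt  = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
    const int idx_M00    = idx_M0 / M01;
    const int idx_M01    = idx_M0 - idx_M00 * M01;
    const int idx_local  = idx_N0 + idx_M01 * N0;
    const int n_out      = idx_local / M01_adapt;
    return {idx_local - n_out * M01_adapt + idx_M00 * M01, n_out};
}
static_assert(spatially_local_tile(5, 5, 4, 1, 2).m == 1 &&
                  spatially_local_tile(5, 5, 4, 1, 2).n == 2,
              "matches the worked example in the comment above");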
} // namespace ck_tile } // namespace ck_tile