Commit 546a764e authored by Artur Wojcik
Browse files

Merge branch 'migraphx' into uif2-migraphx

parents 8da3dfff 57cdd70b
......@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
......
......@@ -5,7 +5,22 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
#endif // __HIPCC_RTC__
namespace ck {
#ifdef __HIPCC_RTC__
using byte = unsigned char;
#else
using std::byte;
#endif
using bhalf_t = ushort;
using half_t = _Float16;
......@@ -974,20 +989,96 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
template <typename T>
struct NumericLimits
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr T QuietNaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
/// Numeric limits for the signed 16-bit integer type.
template <>
struct NumericLimits<int16_t>
{
    /// Largest representable value (32767).
    __host__ __device__ static constexpr int16_t Max() noexcept { return 0x7FFF; }

    /// Most negative representable value (-32768).
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -Max() - 1; }

    /// Smallest representable value; identical to Lowest() for this type.
    __host__ __device__ static constexpr int16_t Min() noexcept { return Lowest(); }

    /// Integers have no infinity; this returns 0.
    __host__ __device__ static constexpr int16_t Infinity() noexcept { return int16_t{0}; }

    /// Integers have no NaN; this returns 0.
    __host__ __device__ static constexpr int16_t QuietNaN() { return int16_t{0}; }
};
/// Numeric limits for the signed 8-bit integer type.
template <>
struct NumericLimits<int8_t>
{
    /// Largest representable value (127).
    __host__ __device__ static constexpr int8_t Max() noexcept { return 0x7F; }

    /// Most negative representable value (-128).
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -Max() - 1; }

    /// Smallest representable value; identical to Lowest() for this type.
    __host__ __device__ static constexpr int8_t Min() noexcept { return Lowest(); }

    /// Integers have no infinity; this returns 0.
    __host__ __device__ static constexpr int8_t Infinity() noexcept { return int8_t{0}; }

    /// Integers have no NaN; this returns 0.
    __host__ __device__ static constexpr int8_t QuietNaN() { return int8_t{0}; }
};
/// Numeric limits for the unsigned 32-bit integer type.
template <>
struct NumericLimits<uint32_t>
{
    /// Largest representable value (4294967295).
    __host__ __device__ static constexpr uint32_t Max() noexcept { return 0xFFFFFFFFU; }

    /// Smallest representable value (0).
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0U; }

    /// Identical to Lowest() for this type.
    __host__ __device__ static constexpr uint32_t Min() noexcept { return Lowest(); }

    /// Integers have no infinity; this returns 0.
    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0U; }

    /// Integers have no NaN; this returns 0.
    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0U; }
};
/// Numeric limits for the unsigned 16-bit integer type.
template <>
struct NumericLimits<uint16_t>
{
    /// Largest representable value (65535).
    __host__ __device__ static constexpr uint16_t Max() noexcept { return 0xFFFFU; }

    /// Smallest representable value (0).
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return uint16_t{0}; }

    /// Identical to Lowest() for this type.
    __host__ __device__ static constexpr uint16_t Min() noexcept { return Lowest(); }

    /// Integers have no infinity; this returns 0.
    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return uint16_t{0}; }

    /// Integers have no NaN; this returns 0.
    __host__ __device__ static constexpr uint16_t QuietNaN() { return uint16_t{0}; }
};
/// Numeric limits for `float`.
///
/// Every value is stored as a 32-bit pattern and converted with bit_cast.
/// The patterns assume `float` uses the IEEE-754 binary32 layout
/// (1 sign bit, 8 exponent bits, 23 mantissa bits).
template <>
struct NumericLimits<float>
{
    static constexpr unsigned int binary_min    = 0x00800000; // smallest positive normal value
    static constexpr unsigned int binary_max    = 0x7F7FFFFF; // largest finite value
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF; // most negative finite value
    static constexpr unsigned int binary_qnan   = 0xFFC00001; // quiet NaN (quiet bit set)
    // +infinity: exponent all ones, mantissa zero. The literal must have exactly eight hex
    // digits; a longer one is silently truncated when converted to unsigned int and then
    // no longer encodes infinity.
    static constexpr unsigned int binary_inf = 0x7F800000;

    /// Smallest positive normal value.
    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }

    /// Largest finite value.
    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }

    /// Most negative finite value.
    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }

    /// A quiet NaN.
    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }

    /// Positive infinity.
    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
......
......@@ -3,7 +3,7 @@
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck {
namespace debug {
......
......@@ -4,11 +4,26 @@
#pragma once
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#else
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#endif
} // namespace ck
......@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
return ck::forward<X>(x);
}
else
{
return std::forward<Y>(y);
return ck::forward<Y>(y);
}
}
......
......@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
}
};
......@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
ck::forward<Y>(y).At(Number<Js>{})...);
}
};
......@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
ck::forward<F>(f), ck::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
......@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
}
} // namespace ck
......
......@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return integral_constant<decltype(X % Y), X % Y>{};
}
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
} // namespace ck
......@@ -9,6 +9,8 @@
#include "type.hpp"
#include "tuple.hpp"
// Provide INT32_MAX only when no earlier header (for example <cstdint>) has already
// defined it; an unconditional #define would redefine the standard macro.
#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif
namespace ck {
// magic number division
......
......@@ -14,6 +14,7 @@
namespace ck {
namespace math {
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); };
......@@ -183,7 +184,7 @@ inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
......
......@@ -2,12 +2,13 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck/utility/ignore.hpp>
namespace ck {
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
......@@ -23,7 +24,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
......@@ -40,12 +41,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
std::ignore = id;
std::ignore = val;
std::ignore = seed;
ck::ignore = id;
ck::ignore = val;
ck::ignore = seed;
return 0;
}
......
......@@ -32,7 +32,7 @@ struct TupleElementKeyData
template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
{
}
......@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
/// Returns a copy of the element stored in `x`.
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
    // `forward` cannot be called without an explicit template argument (its parameter is a
    // non-deduced context), and `x` is a const lvalue returned by value, so copy the member.
    return x.mData;
}
template <typename Indices, typename... Xs>
......@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
{
}
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
......@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename Y,
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
__host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
{
}
template <typename... Ys,
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
{
}
......@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
......
......@@ -28,7 +28,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
[&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
......
......@@ -4,10 +4,122 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
#ifdef __HIPCC_RTC__
// Defines a unary type trait `name<T>` that derives from
// bool_constant<__name(T)>, i.e. it forwards to the compiler builtin `__name`.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// Defines a binary type trait `name<T, U>` that derives from
// bool_constant<__name(T, U)>.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// Defines a variadic type trait `name<Ts...>` that derives from
// bool_constant<__name(Ts...)>.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
/// Strips top-level `const` and/or `volatile` from T; the result is `remove_cv<T>::type`.
template <class T>
struct remove_cv
{
    using type = T;
};
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
// Without this specialization `const volatile T` matches both partial specializations
// above equally well, and the instantiation is ambiguous.
template <class T>
struct remove_cv<const volatile T> : remove_cv<T>
{
};
/// Yields the type a reference refers to; non-reference types are returned unchanged.
template <class Referred>
struct remove_reference
{
    using type = Referred;
};

/// Lvalue reference: strip `&`.
template <class Referred>
struct remove_reference<Referred&>
{
    using type = Referred;
};

/// Rvalue reference: strip `&&`.
template <class Referred>
struct remove_reference<Referred&&>
{
    using type = Referred;
};
/// Yields the pointee type of a (possibly cv-qualified) pointer; other types are
/// returned unchanged.
template <class Pointee>
struct remove_pointer
{
    using type = Pointee;
};

/// Plain pointer: strip one level of `*`.
template <class Pointee>
struct remove_pointer<Pointee*>
{
    using type = Pointee;
};

// The cv-qualified pointer forms reuse the plain-pointer result.
template <class Pointee>
struct remove_pointer<Pointee* const> : remove_pointer<Pointee*>
{
};

template <class Pointee>
struct remove_pointer<Pointee* volatile> : remove_pointer<Pointee*>
{
};

template <class Pointee>
struct remove_pointer<Pointee* const volatile> : remove_pointer<Pointee*>
{
};
/// Forwards an lvalue argument as `T&&`: an lvalue when T is an lvalue reference,
/// an rvalue otherwise.
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& arg) noexcept
{
    return static_cast<T&&>(arg);
}

/// Forwards an rvalue argument as `T&&`.
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& arg) noexcept
{
    return static_cast<T&&>(arg);
}
#else
#include <utility>
#include <type_traits>
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_pointer;
using std::is_reference;
using std::is_trivially_copyable;
using std::is_unsigned;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
#endif
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
......@@ -19,25 +131,39 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;
template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
using remove_cv_t = typename remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
using remove_pointer_t = typename remove_pointer<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
template <typename Y,
typename X,
typename ck::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
......
......@@ -15,7 +15,7 @@ template <typename Y,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x);
}
......@@ -356,7 +356,7 @@ template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
......@@ -392,7 +392,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/utility)
if (CK_BUILD_JIT_LIB)
add_subdirectory(src/jit_library)
else()
add_subdirectory(src/tensor_operation_instance/gpu)
add_subdirectory(src/utility)
endif()
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${PROJECT_SOURCE_DIR}/include/ck/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
message(STATUS "RELATIVE: ${PROJECT_SOURCE_DIR}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/include)
# Run make_instance_strings.py at configure time, writing into
# ${CMAKE_CURRENT_BINARY_DIR}/solution_instances. Stop the configure step if the
# script fails (or python3 cannot be started), so the build does not continue
# without the generated files.
execute_process(
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/util/make_instance_strings.py
            ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
            ${CMAKE_CURRENT_BINARY_DIR}/solution_instances
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../tensor_operation_instance/gpu/
    RESULT_VARIABLE make_instance_strings_result
)
if(NOT make_instance_strings_result EQUAL 0)
    message(FATAL_ERROR
        "make_instance_strings.py failed with result: ${make_instance_strings_result}")
endif()
add_library(jit_library STATIC
src/device_batched_gemm_softmax_gemm.cpp
src/device_gemm_multiple_d.cpp
src/common.cpp
)
add_library(composable_kernel::jit_library ALIAS jit_library)
set_target_properties(jit_library PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(jit_library SYSTEM PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
)
target_link_libraries(jit_library PRIVATE $<BUILD_INTERFACE:ck_headers>)
rocm_install(
TARGETS jit_library
EXPORT jit_libraryTargets
)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
rocm_install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
rocm_install(
EXPORT jit_libraryTargets
FILE composable_kerneljit_libraryTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <string_view>
#include <utility>
#include <unordered_map>
namespace ck {
namespace host {
/// One generated solution: a template string plus two launch sizes.
/// NOTE(review): `template_str` presumably holds the kernel template instantiation
/// text; confirm against the code that fills it in.
struct Solution
{
    std::string template_str;
    // Zero-initialized so a default-constructed Solution never carries indeterminate sizes.
    std::size_t block_size = 0;
    std::size_t grid_size  = 0;
};
/// Element data types that can be named in a problem description.
enum class DataType
{
    Half  = 0,
    Float = 1,
    Int8  = 2,
    Int32 = 3
};
/// Returns the C++ type name for `dt` as a string ("float", "ck::half_t", "int8_t"
/// or "int32_t"). Throws std::runtime_error for a value outside the enum.
std::string ToString(DataType dt);
/// Returns the map produced by ck_headers() with an additional empty "ck/config.h"
/// entry (added only if that key is not already present).
/// NOTE(review): keys look like header paths and values like header contents; confirm.
std::unordered_map<std::string_view, std::string_view> GetHeaders();
/// Returns x divided by y, rounded up. `y` must be non-zero.
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
/// Description of one batched GEMM + softmax + GEMM problem: sizes, transposition
/// flags, data types and element-wise operation names.
/// NOTE(review): the private `*_idx` constants look like positions of the named
/// template arguments inside an instance string; confirm against the .cpp.
struct Problem
{
    // Problem sizes.
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;
    std::size_t O = 0;

    // Transposition flags of the operands.
    bool TransA  = false;
    bool TransB  = false;
    bool TransB1 = false;
    bool TransC  = false;

    // Element types of the operands.
    DataType ADataType  = DataType::Half;
    DataType BDataType  = DataType::Half;
    DataType B1DataType = DataType::Half;
    DataType CDataType  = DataType::Half;

    // Names of the element-wise operations, spelled as C++ type names.
    std::string AElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string B1ElementOp  = "ck::tensor_operation::element_wise::PassThrough";
    std::string CElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string AccElementOp = "ck::tensor_operation::element_wise::Scale";

    std::string GetIncludeHeader() const;
    std::vector<Solution> GetSolutions(const std::string& arch) const;

    private:
    std::vector<std::string> GetInstances(const std::string& arch) const;
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;

    // `static constexpr` members are implicitly inline (C++17), so they may be ODR-used
    // (e.g. bound to a const reference) without a separate out-of-class definition.
    static constexpr std::size_t DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx = 0;
    static constexpr std::size_t ALayout_idx                                   = 1;
    static constexpr std::size_t B0Layout_idx                                  = 2;
    static constexpr std::size_t B1Layout_idx                                  = 3;
    static constexpr std::size_t CLayout_idx                                   = 4;
    static constexpr std::size_t ADataType_idx                                 = 5;
    static constexpr std::size_t B0DataType_idx                                = 6;
    static constexpr std::size_t B1DataType_idx                                = 7;
    static constexpr std::size_t CDataType_idx                                 = 8;
    static constexpr std::size_t AccDataType_idx                               = 9;
    static constexpr std::size_t CShuffleDataType_idx                          = 10;
    static constexpr std::size_t AElementwiseOperation_idx                     = 11;
    static constexpr std::size_t B0ElementwiseOperation_idx                    = 12;
    static constexpr std::size_t Acc0ElementwiseOperation_idx                  = 13;
    static constexpr std::size_t B1ElementwiseOperation_idx                    = 14;
    static constexpr std::size_t CElementwiseOperation_idx                     = 15;
    static constexpr std::size_t GEMMSpecialization_idx                        = 16;
    static constexpr std::size_t NumGemmKPrefetchStage_idx                     = 17;
    static constexpr std::size_t BlockSize_idx                                 = 18;
    static constexpr std::size_t Gemm01MPerBlock_idx                           = 19;
    static constexpr std::size_t Gemm0NPerBlock_idx                            = 20;
    static constexpr std::size_t Gemm0KPerBlock_idx                            = 21;
    static constexpr std::size_t Gemm1NPerBlock_idx                            = 22;
    static constexpr std::size_t Gemm1KPerBlock_idx                            = 23;
    static constexpr std::size_t AK1_idx                                       = 24;
    static constexpr std::size_t BK1_idx                                       = 25;
    static constexpr std::size_t B1K1_idx                                      = 26;
    static constexpr std::size_t MPerXDL_idx                                   = 27;
    static constexpr std::size_t NPerXDL_idx                                   = 28;
    static constexpr std::size_t Gemm0MXdlPerWave_idx                          = 29;
    static constexpr std::size_t Gemm0NXdlPerWave_idx                          = 30;
    static constexpr std::size_t Gemm1NXdlPerWave_idx                          = 31;
    static constexpr std::size_t ABlockTransferThreadClusterLengths_K0_M_K1_idx  = 32;
    static constexpr std::size_t ABlockTransferThreadClusterArrangeOrder_idx     = 33;
    static constexpr std::size_t ABlockTransferSrcAccessOrder_idx                = 34;
    static constexpr std::size_t ABlockTransferSrcVectorDim_idx                  = 35;
    static constexpr std::size_t ABlockTransferSrcScalarPerVector_idx            = 36;
    static constexpr std::size_t ABlockTransferDstScalarPerVector_K1_idx         = 37;
    static constexpr std::size_t ABlockLdsAddExtraM_idx                          = 38;
    static constexpr std::size_t B0BlockTransferThreadClusterLengths_K0_N_K1_idx = 39;
    static constexpr std::size_t B0BlockTransferThreadClusterArrangeOrder_idx    = 40;
    static constexpr std::size_t B0BlockTransferSrcAccessOrder_idx               = 41;
    static constexpr std::size_t B0BlockTransferSrcVectorDim_idx                 = 42;
    static constexpr std::size_t B0BlockTransferSrcScalarPerVector_idx           = 43;
    static constexpr std::size_t B0BlockTransferDstScalarPerVector_K1_idx        = 44;
    static constexpr std::size_t B0BlockLdsAddExtraN_idx                         = 45;
    static constexpr std::size_t B1BlockTransferThreadClusterLengths_K0_N_K1_idx = 46;
    static constexpr std::size_t B1BlockTransferThreadClusterArrangeOrder_idx    = 47;
    static constexpr std::size_t B1BlockTransferSrcAccessOrder_idx               = 48;
    static constexpr std::size_t B1BlockTransferSrcVectorDim_idx                 = 49;
    static constexpr std::size_t B1BlockTransferSrcScalarPerVector_idx           = 50;
    static constexpr std::size_t B1BlockTransferDstScalarPerVector_K1_idx        = 51;
    static constexpr std::size_t B1BlockLdsAddExtraN_idx                         = 52;
    static constexpr std::size_t CShuffleMXdlPerWavePerShuffle_idx               = 53;
    static constexpr std::size_t CShuffleNXdlPerWavePerShuffle_idx               = 54;
    static constexpr std::size_t
        CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx = 55;
    static constexpr std::size_t CBlockTransferScalarPerVector_NWaveNPerXdl_idx  = 56;
    static constexpr std::size_t MaskOutUpperTriangle_idx                        = 57;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace ck {
namespace host {
namespace device_gemm_multiple_d {
/// Description of one GEMM problem with multiple D tensors: sizes, transposition
/// flags, data types and element-wise operation names.
/// NOTE(review): the `*_idx` constants look like positions of the named template
/// arguments inside an instance string; confirm against the .cpp.
struct Problem
{
    // Problem sizes.
    std::size_t M = 0;
    std::size_t N = 0;
    std::size_t K = 0;

    // Transposition flags; DsTrans holds one entry per D tensor.
    bool TransA               = false;
    bool TransB               = false;
    bool TransE               = false;
    std::vector<bool> DsTrans = {};

    // Element types; DsDataType holds one entry per D tensor.
    DataType ADataType               = DataType::Half;
    DataType BDataType               = DataType::Half;
    DataType EDataType               = DataType::Half;
    std::vector<DataType> DsDataType = {};

    // Names of the element-wise operations, spelled as C++ type names.
    std::string AElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp   = "ck::tensor_operation::element_wise::PassThrough";
    std::string CDEElementOp = "ck::Tuple<>";

    // `static constexpr` members are implicitly inline (C++17), so they may be ODR-used
    // (e.g. bound to a const reference) without a separate out-of-class definition.
    static constexpr std::size_t ds_layout_idx         = 3;
    static constexpr std::size_t ds_data_type_idx      = 9;
    static constexpr std::size_t e_data_type_idx       = 10;
    static constexpr std::size_t a_elementwise_op_idx  = 11;
    static constexpr std::size_t b_elementwise_op_idx  = 12;
    static constexpr std::size_t ds_elementwise_op_idx = 13;
    static constexpr std::size_t gemm_spec_idx         = 14;
    static constexpr std::size_t block_size_idx        = 16;
    static constexpr std::size_t m_per_block_idx       = 17;
    static constexpr std::size_t n_per_block_idx       = 18;
    static constexpr std::size_t k_per_block_idx       = 19;

    std::string GetIncludeHeader() const;
    std::vector<Solution> GetSolutions(const std::string& arch) const;

    private:
    std::vector<std::string> GetInstances(const std::string& arch) const;
    Solution MakeSolution(std::size_t idx, const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
} // namespace host
} // namespace ck
#include "ck/host/common.hpp"
#include "ck_headers.hpp"
#include <stdexcept>
#include <algorithm>
namespace ck {
namespace host {
/// Maps a DataType to the C++ type name used for it.
/// Throws std::runtime_error when `dt` is not one of the known enumerators.
std::string ToString(DataType dt)
{
    if(dt == DataType::Float)
    {
        return "float";
    }
    if(dt == DataType::Half)
    {
        return "ck::half_t";
    }
    if(dt == DataType::Int8)
    {
        return "int8_t";
    }
    if(dt == DataType::Int32)
    {
        return "int32_t";
    }
    // Reached only for a value outside the enumerator list.
    throw std::runtime_error("Incorrect data type");
}
std::unordered_map<std::string_view, std::string_view> GetHeaders()
{
auto headers = ck_headers();
headers.insert(
{"ck/config.h", ""});
return headers;
}
/// Returns x divided by y, rounded up to the next integer.
///
/// `y` must be non-zero. The quotient and remainder are used directly because
/// `x + y - 1` wraps around for large `x` and would give a wrong (too small) result.
std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
{
    const std::size_t quotient = x / y;
    return (x % y != 0) ? quotient + std::size_t{1} : quotient;
}
} // namespace host
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment