Commit 5e28c17a authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add mfma selection

parent 29eaa2dc
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -716,6 +716,12 @@ struct MfmaSelector ...@@ -716,6 +716,12 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <>
static constexpr auto GetMfma<custom_half_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <> template <>
static constexpr auto GetMfma<half_t, 16, 16>() static constexpr auto GetMfma<half_t, 16, 16>()
{ {
......
...@@ -17,8 +17,10 @@ struct custom_half_t ...@@ -17,8 +17,10 @@ struct custom_half_t
{ {
using type = short; using type = short;
type data; type data;
custom_half_t() : data{type{}} {} __host__ __device__ constexpr custom_half_t() : data{type{}} {}
custom_half_t(type init) : data{init} {} __host__ __device__ constexpr custom_half_t(type init) : data{init} {}
__host__ __device__ constexpr custom_half_t(int init) : data{static_cast<type>(init)} {}
__host__ __device__ constexpr custom_half_t(float init) : data{static_cast<type>(init)} {}
}; };
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
...@@ -37,6 +39,22 @@ inline constexpr bool is_native_type() ...@@ -37,6 +39,22 @@ inline constexpr bool is_native_type()
is_same<T, bool>::value; is_same<T, bool>::value;
} }
template <typename T, index_t N>
struct non_native_vector_base
{
using VecT = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const VecT&) = default;
__host__ __device__ non_native_vector_base(VecT&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// vector_type // vector_type
template <typename T, index_t N, typename Enable = void> template <typename T, index_t N, typename Enable = void>
struct vector_type; struct vector_type;
...@@ -114,7 +132,13 @@ struct scalar_type<vector_type<T, N>> ...@@ -114,7 +132,13 @@ struct scalar_type<vector_type<T, N>>
static constexpr index_t vector_size = N; static constexpr index_t vector_size = N;
}; };
// template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>
{
using type = T;
static constexpr index_t vector_size = N;
};
template <> template <>
struct scalar_type<double> struct scalar_type<double>
{ {
...@@ -1021,22 +1045,6 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>> ...@@ -1021,22 +1045,6 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
} }
}; };
template <typename T, index_t N>
struct non_native_vector_base
{
using VecT = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const VecT&) = default;
__host__ __device__ non_native_vector_base(VecT&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type implementation // non-native vector_type implementation
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
...@@ -1631,6 +1639,14 @@ using half16_t = typename vector_type<half_t, 16>::type; ...@@ -1631,6 +1639,14 @@ using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type; using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type; using half64_t = typename vector_type<half_t, 64>::type;
// custom fp16
using custom_half2_t = typename vector_type<custom_half_t, 2>::type;
using custom_half4_t = typename vector_type<custom_half_t, 4>::type;
using custom_half8_t = typename vector_type<custom_half_t, 8>::type;
using custom_half16_t = typename vector_type<custom_half_t, 16>::type;
using custom_half32_t = typename vector_type<custom_half_t, 32>::type;
using custom_half64_t = typename vector_type<custom_half_t, 64>::type;
// bfp16 // bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type; using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type; using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment