Commit 5e28c17a authored by Rostyslav Geyyer
Browse files

Add mfma selection

parent 29eaa2dc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -716,6 +716,12 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_32x32x8f16;
}
// Selection for custom_half_t with size arguments <32, 32>: returns
// mfma_f32_32x32x8f16, the same instruction the preceding specialization returns.
// NOTE(review): the two integer arguments presumably are the M/N extents of the
// instruction tile -- confirm against the primary GetMfma template.
template <>
static constexpr auto GetMfma<custom_half_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
static constexpr auto GetMfma<half_t, 16, 16>()
{
......
......@@ -17,8 +17,10 @@ struct custom_half_t
{
    // Storage type of the wrapped value.
    using type = short;

    // The stored value; public so the struct can be read and written directly.
    type data;

    // Only the constexpr __host__ __device__ constructors are kept: the earlier
    // unqualified custom_half_t() / custom_half_t(type) overloads had identical
    // signatures and would be redefinitions.

    // Value-initializes the storage (data == 0).
    __host__ __device__ constexpr custom_half_t() : data{type{}} {}

    // Stores `init` unchanged.
    __host__ __device__ constexpr custom_half_t(type init) : data{init} {}

    // Converts the integer value to the storage type via static_cast.
    __host__ __device__ constexpr custom_half_t(int init) : data{static_cast<type>(init)} {}

    // Converts the numeric value to the storage type via static_cast, which
    // discards the fractional part; it does not produce a half-precision bit pattern.
    // NOTE(review): confirm that numeric conversion (not fp16 encoding) is intended.
    __host__ __device__ constexpr custom_half_t(float init) : data{static_cast<type>(init)} {}
};
inline constexpr auto next_pow2(uint32_t x)
......@@ -37,6 +39,22 @@ inline constexpr bool is_native_type()
is_same<T, bool>::value;
}
// Holds N elements of T in the plain array member `d`.
// NOTE(review): presumably the storage type behind the non-native vector_type
// specializations -- confirm against their definitions.
template <typename T, index_t N>
struct non_native_vector_base
{
    using VecT = non_native_vector_base<T, N>;

    __host__ __device__ non_native_vector_base()            = default;
    __host__ __device__ non_native_vector_base(const VecT&) = default;
    __host__ __device__ non_native_vector_base(VecT&&)      = default;

    // A user-declared move constructor makes the implicit copy assignment
    // operator deleted, so both assignment operators are defaulted explicitly
    // to keep the type assignable.
    __host__ __device__ VecT& operator=(const VecT&) = default;
    __host__ __device__ VecT& operator=(VecT&&)      = default;

    __host__ __device__ ~non_native_vector_base() = default;

    // Element storage.
    T d[N];
};
// vector_type
// Primary template, declared here without a definition; the definitions come
// from partial specializations. The defaulted `Enable` parameter is the slot
// those specializations fill with std::enable_if_t conditions on
// is_native_type<T>().
template <typename T, index_t N, typename Enable = void>
struct vector_type;
......@@ -114,7 +132,13 @@ struct scalar_type<vector_type<T, N>>
static constexpr index_t vector_size = N;
};
//
// scalar_type for non_native_vector_base<T, N>: reports the element count N as
// `vector_size` and the element type T as `type`.
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>
{
    static constexpr index_t vector_size = N;
    using type                           = T;
};
template <>
struct scalar_type<double>
{
......@@ -1021,22 +1045,6 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
};
// Holds N elements of T in the plain array member `d`.
// NOTE(review): presumably the storage type behind the non-native vector_type
// specializations -- confirm against their definitions.
template <typename T, index_t N>
struct non_native_vector_base
{
    using VecT = non_native_vector_base<T, N>;

    __host__ __device__ non_native_vector_base()            = default;
    __host__ __device__ non_native_vector_base(const VecT&) = default;
    __host__ __device__ non_native_vector_base(VecT&&)      = default;

    // A user-declared move constructor makes the implicit copy assignment
    // operator deleted, so both assignment operators are defaulted explicitly
    // to keep the type assignable.
    __host__ __device__ VecT& operator=(const VecT&) = default;
    __host__ __device__ VecT& operator=(VecT&&)      = default;

    __host__ __device__ ~non_native_vector_base() = default;

    // Element storage.
    T d[N];
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
......@@ -1631,6 +1639,14 @@ using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// custom fp16
// Vector aliases for custom_half_t at widths 2, 4, 8, 16, 32 and 64, each
// resolved through vector_type<custom_half_t, N>::type like the half*_t
// aliases above.
// NOTE(review): presumably these select the non-native vector_type
// specializations, since custom_half_t is a user-defined struct -- confirm
// against is_native_type.
using custom_half2_t = typename vector_type<custom_half_t, 2>::type;
using custom_half4_t = typename vector_type<custom_half_t, 4>::type;
using custom_half8_t = typename vector_type<custom_half_t, 8>::type;
using custom_half16_t = typename vector_type<custom_half_t, 16>::type;
using custom_half32_t = typename vector_type<custom_half_t, 32>::type;
using custom_half64_t = typename vector_type<custom_half_t, 64>::type;
// bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment