Commit aa1920da authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add fp4 vectors

parent 9433306a
......@@ -1191,62 +1191,6 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
}
};
template <>
struct NumericLimits<f4_t>
{
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
static constexpr float data_max_normal_number = 6;
static constexpr float data_min_subnormal_number = 0.5;
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
template <>
struct NumericLimits<e8m0_scale_t>
{
static constexpr e8m0_scale_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_scale_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_scale_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_scale_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_scale_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_scale_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_scale_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_scale_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_scale_t Min() { return e8m0_scale_t(binary_min); }
__host__ __device__ static constexpr e8m0_scale_t Max() { return e8m0_scale_t(binary_max); }
__host__ __device__ static constexpr e8m0_scale_t QuietNaN()
{
return e8m0_scale_t(binary_qnan);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_1() { return e8m0_scale_t(binary_1); }
__host__ __device__ static constexpr e8m0_scale_t Binary_2() { return e8m0_scale_t(binary_2); }
__host__ __device__ static constexpr e8m0_scale_t Binary_3() { return e8m0_scale_t(binary_3); }
__host__ __device__ static constexpr e8m0_scale_t Binary_135()
{
return e8m0_scale_t(binary_135);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_142()
{
return e8m0_scale_t(binary_142);
}
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
......@@ -1643,6 +1587,136 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
}
};
template <>
struct vector_type<f4_t, 2, typename std::enable_if_t<!is_native_type<f4_t>()>>
{
using d1_t = f4_t;
using d2_t = uint8_t;
using type = d2_t;
union alignas(next_pow2(sizeof(type)))
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
template <>
struct vector_type<f4_t, 4, typename std::enable_if_t<!is_native_type<f4_t>()>>
{
using d1_t = f4_t;
using d2_t = uint8_t;
using d4_t = uint16_t;
using type = d4_t;
union alignas(next_pow2(sizeof(type)))
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
using int64_t = long;
// fp64
......@@ -1805,6 +1879,62 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
template <>
struct NumericLimits<f4_t>
{
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
static constexpr float data_max_normal_number = 6;
static constexpr float data_min_subnormal_number = 0.5;
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
template <>
struct NumericLimits<e8m0_scale_t>
{
static constexpr e8m0_scale_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_scale_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_scale_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_scale_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_scale_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_scale_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_scale_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_scale_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_scale_t Min() { return e8m0_scale_t(binary_min); }
__host__ __device__ static constexpr e8m0_scale_t Max() { return e8m0_scale_t(binary_max); }
__host__ __device__ static constexpr e8m0_scale_t QuietNaN()
{
return e8m0_scale_t(binary_qnan);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_1() { return e8m0_scale_t(binary_1); }
__host__ __device__ static constexpr e8m0_scale_t Binary_2() { return e8m0_scale_t(binary_2); }
__host__ __device__ static constexpr e8m0_scale_t Binary_3() { return e8m0_scale_t(binary_3); }
__host__ __device__ static constexpr e8m0_scale_t Binary_135()
{
return e8m0_scale_t(binary_135);
}
__host__ __device__ static constexpr e8m0_scale_t Binary_142()
{
return e8m0_scale_t(binary_142);
}
};
template <typename T>
struct NumericUtils
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment