Commit d032ea56 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add docstrings

parent 97c7e725
...@@ -681,6 +681,14 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc ...@@ -681,6 +681,14 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
#endif #endif
} }
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <> template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale, inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
f6x32_t x) f6x32_t x)
...@@ -739,6 +747,14 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s ...@@ -739,6 +747,14 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
#endif #endif
} }
/**
* @brief Converts a vector of 32 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <> template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale, inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
bf6x32_t x) bf6x32_t x)
...@@ -786,6 +802,17 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca ...@@ -786,6 +802,17 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
#endif #endif
} }
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template <> template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale, inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x) float32_t x)
...@@ -818,6 +845,17 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s ...@@ -818,6 +845,17 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
#endif #endif
} }
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template <> template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale, inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
float32_t x) float32_t x)
......
...@@ -1418,6 +1418,16 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f) ...@@ -1418,6 +1418,16 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
#endif #endif
} }
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* round-to-nearest-even rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f) inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
...@@ -1510,6 +1520,16 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) ...@@ -1510,6 +1520,16 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
#endif #endif
} }
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f) inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
...@@ -1557,17 +1577,28 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f ...@@ -1557,17 +1577,28 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
template <> template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x) inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{ {
#if defined(__gfx950__) #if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x); return f6_convert_sr(x);
#else #else
return f6_convert_rne(x); return f6_convert_rne(x);
#endif #endif
} }
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into a
* vector of 32 6-bit float values (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input vector of 32 floats (float32_t) to be converted.
* @return The converted f6x32_t vector.
*/
template <> template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x) inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{ {
#if defined(__gfx950__) #if CK_USE_SR_F6_CONVERSION
return f6_convert_sr(x); return f6_convert_sr(x);
#else #else
return f6_convert_rne(x); return f6_convert_rne(x);
...@@ -1607,6 +1638,15 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x) ...@@ -1607,6 +1638,15 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
#endif #endif
} }
/**
* @brief Specializes the type conversion template for converting a vector of 32 6-bit float values
* (f6x32_t) to a vector of 32 floats.
*
* Interprets the f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding vector of 32 floats (float32_t).
*/
template <> template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x) inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{ {
...@@ -1665,6 +1705,17 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f) ...@@ -1665,6 +1705,17 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
#endif #endif
} }
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f) inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{ {
#if defined(__gfx950__) #if defined(__gfx950__)
...@@ -1758,6 +1809,18 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) ...@@ -1758,6 +1809,18 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
#endif #endif
} }
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f) inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
...@@ -1810,6 +1873,15 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x) ...@@ -1810,6 +1873,15 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
#endif #endif
} }
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats to a
* vector of 32 bf6_t values (bf6x32_t).
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <> template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x) inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{ {
...@@ -1853,6 +1925,16 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x) ...@@ -1853,6 +1925,16 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
#endif #endif
} }
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values
* (bf6x32_t) to a vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <> template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x) inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment