Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d032ea56
Commit
d032ea56
authored
Jan 30, 2025
by
Rostyslav Geyyer
Browse files
Add docstrings
parent
97c7e725
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
122 additions
and
2 deletions
+122
-2
include/ck/utility/scaled_type_convert.hpp
include/ck/utility/scaled_type_convert.hpp
+38
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+84
-2
No files found.
include/ck/utility/scaled_type_convert.hpp
View file @
d032ea56
...
...
@@ -681,6 +681,14 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f6x32_t
>
(
e8m0_bexp_t
scale
,
f6x32_t
x
)
...
...
@@ -739,6 +747,14 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
#endif
}
/**
* @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
* @return The converted vector of 32 float representation of the input.
*/
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
bf6x32_t
>
(
e8m0_bexp_t
scale
,
bf6x32_t
x
)
...
...
@@ -786,6 +802,17 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template
<
>
inline
__host__
__device__
f6x32_t
scaled_type_convert
<
f6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
...
...
@@ -818,6 +845,17 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template
<
>
inline
__host__
__device__
bf6x32_t
scaled_type_convert
<
bf6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
...
...
include/ck/utility/type_convert.hpp
View file @
d032ea56
...
...
@@ -1418,6 +1418,16 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* rounding to nearest / even to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline
__host__
__device__
f6x32_t
f6_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
...
...
@@ -1510,6 +1520,16 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* stochastic rounding to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline
__host__
__device__
f6x32_t
f6_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
...
...
@@ -1557,17 +1577,28 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
template
<
>
inline
__host__
__device__
f6_t
type_convert
<
f6_t
,
float
>
(
float
x
)
{
#if
defined(__gfx950__)
#if
CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
);
#else
return
f6_convert_rne
(
x
);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input float value to be converted.
* @return The converted f6x32_t vector.
*/
template
<
>
inline
__host__
__device__
f6x32_t
type_convert
<
f6x32_t
,
float32_t
>
(
float32_t
x
)
{
#if
defined(__gfx950__)
#if
CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
);
#else
return
f6_convert_rne
(
x
);
...
...
@@ -1607,6 +1638,15 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
#endif
}
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
*
* Interprets an f6_t values as floats using the default scale factor of 1.
*
* @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
* @return The corresponding float representation.
*/
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
f6x32_t
>
(
f6x32_t
x
)
{
...
...
@@ -1665,6 +1705,17 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline
__host__
__device__
bf6x32_t
bf6_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
...
...
@@ -1758,6 +1809,18 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
#endif
}
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline
__host__
__device__
bf6x32_t
bf6_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
...
...
@@ -1810,6 +1873,15 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template
<
>
inline
__host__
__device__
bf6x32_t
type_convert
<
bf6x32_t
,
float32_t
>
(
float32_t
x
)
{
...
...
@@ -1853,6 +1925,16 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
* vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
bf6x32_t
>
(
bf6x32_t
x
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment