Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b8f4de71
Commit
b8f4de71
authored
Jan 29, 2025
by
Rostyslav Geyyer
Browse files
Add conversions
parent
c98974ee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
587 additions
and
30 deletions
+587
-30
include/ck/utility/scaled_type_convert.hpp
include/ck/utility/scaled_type_convert.hpp
+225
-20
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+362
-10
No files found.
include/ck/utility/scaled_type_convert.hpp
View file @
b8f4de71
...
@@ -10,7 +10,11 @@ namespace ck {
...
@@ -10,7 +10,11 @@ namespace ck {
// Declare a template function for scaled conversion
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
#if CK_USE_NATIVE_MX_SUPPORT || CK_USE_OCP_FP8
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#else
__host__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#endif
// convert f8_ocp_t to fp32
// convert f8_ocp_t to fp32
template
<
>
template
<
>
...
@@ -200,27 +204,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
...
@@ -200,27 +204,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
return
out
.
float_1x32
;
return
out
.
float_1x32
;
}
}
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
inline
__host__
__device__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -231,8 +221,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
...
@@ -231,8 +221,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
float
x
)
#else
inline
__host__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -243,8 +237,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
...
@@ -243,8 +237,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
// convert fp32x2 to fp8x2
// convert fp32x2 to fp8x2
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
float2_t
x
)
#else
inline
__host__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -254,8 +252,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
...
@@ -254,8 +252,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
}
}
// convert fp32x2 to bf8x2
// convert fp32x2 to bf8x2
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
float2_t
x
)
#else
inline
__host__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -267,8 +270,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
...
@@ -267,8 +270,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
// convert fp32x16 to fp8x16
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x16_ocp_t
inline
__host__
__device__
f8x16_ocp_t
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#else
inline
__host__
f8x16_ocp_t
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -280,8 +288,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
...
@@ -280,8 +288,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x16 to bf8x16
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x16_ocp_t
inline
__host__
__device__
bf8x16_ocp_t
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#else
inline
__host__
bf8x16_ocp_t
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -293,8 +306,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
...
@@ -293,8 +306,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x32 to fp8x32
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x32_ocp_t
inline
__host__
__device__
f8x32_ocp_t
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
f8x32_ocp_t
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -306,8 +324,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
...
@@ -306,8 +324,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
// convert fp32x32 to bf8x32
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x32_ocp_t
inline
__host__
__device__
bf8x32_ocp_t
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
bf8x32_ocp_t
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -316,10 +339,36 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
...
@@ -316,10 +339,36 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
#endif
}
}
// convert fp4 to fp32
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
#else
inline
__host__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
#endif
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
f4x2_t
x
)
#else
inline
__host__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
#endif
{
{
#if defined(__gfx950__)
#if defined(__gfx950__)
union
union
...
@@ -340,8 +389,12 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
...
@@ -340,8 +389,12 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
// convert vector of 32 fp4 to vector of 32 fp32
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
f4x32_t
x
)
f4x32_t
x
)
#else
inline
__host__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
f4x32_t
x
)
#endif
{
{
#if defined(__gfx950__)
#if defined(__gfx950__)
union
union
...
@@ -573,7 +626,11 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
...
@@ -573,7 +626,11 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
// convert fp32 to fp4
// convert fp32 to fp4
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -584,8 +641,12 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t sca
...
@@ -584,8 +641,12 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t sca
// convert vector of 2 fp32 to vector of 2 fp4
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
float2_t
x
)
#else
inline
__host__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
#endif
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -596,8 +657,12 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bex
...
@@ -596,8 +657,12 @@ inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bex
// convert vector of 32 fp32 to vector of 32 fp4
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
float32_t
x
)
#else
inline
__host__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -615,10 +680,61 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
...
@@ -615,10 +680,61 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
* @return The converted 32-bit float representation of the input.
* @return The converted 32-bit float representation of the input.
*/
*/
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f6_t
>
(
e8m0_bexp_t
scale
,
f6_t
x
)
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f6_t
>
(
e8m0_bexp_t
scale
,
f6_t
x
)
#else
inline
__host__
float
scaled_type_convert
<
float
,
f6_t
>
(
e8m0_bexp_t
scale
,
f6_t
x
)
#endif
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
in
.
f6_vector
,
type_convert
<
float
>
(
scale
));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f6_t
>
(
scale
,
x
);
return
utils
::
to_float
<
f6_t
>
(
scale
,
x
);
#endif
}
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f6x32_t
>
(
e8m0_bexp_t
scale
,
f6x32_t
x
)
#else
inline
__host__
float32_t
scaled_type_convert
<
float32_t
,
f6x32_t
>
(
e8m0_bexp_t
scale
,
f6x32_t
x
)
#endif
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
x
,
type_convert
<
float
>
(
scale
));
#else
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
f6_t
>
(
scale
,
in
.
f6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
/**
/**
...
@@ -630,10 +746,61 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
...
@@ -630,10 +746,61 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
* @return The converted 32-bit float representation of the input.
* @return The converted 32-bit float representation of the input.
*/
*/
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float
scaled_type_convert
<
float
,
bf6_t
>
(
e8m0_bexp_t
scale
,
bf6_t
x
)
inline
__host__
__device__
float
scaled_type_convert
<
float
,
bf6_t
>
(
e8m0_bexp_t
scale
,
bf6_t
x
)
#else
inline
__host__
float
scaled_type_convert
<
float
,
bf6_t
>
(
e8m0_bexp_t
scale
,
bf6_t
x
)
#endif
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
in
.
bf6_vector
,
type_convert
<
float
>
(
scale
));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
bf6_t
>
(
scale
,
x
);
return
utils
::
to_float
<
bf6_t
>
(
scale
,
x
);
#endif
}
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
bf6x32_t
>
(
e8m0_bexp_t
scale
,
bf6x32_t
x
)
#else
inline
__host__
float32_t
scaled_type_convert
<
float32_t
,
bf6x32_t
>
(
e8m0_bexp_t
scale
,
bf6x32_t
x
)
#endif
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
x
,
type_convert
<
float
>
(
scale
));
#else
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
bf6_t
>
(
scale
,
in
.
bf6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
/**
/**
...
@@ -648,7 +815,26 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
...
@@ -648,7 +815,26 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
* @return The converted 6-bit floating-point value (f6_t).
* @return The converted 6-bit floating-point value (f6_t).
*/
*/
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
f6_t
scaled_type_convert
<
f6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
inline
__host__
__device__
f6_t
scaled_type_convert
<
f6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
f6_t
scaled_type_convert
<
f6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
#if CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f6_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
f6x32_t
scaled_type_convert
<
f6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
f6x32_t
scaled_type_convert
<
f6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F6_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
f6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -669,7 +855,26 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
...
@@ -669,7 +855,26 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
* @return The converted 6-bit floating-point value (bf6_t).
* @return The converted 6-bit floating-point value (bf6_t).
*/
*/
template
<
>
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
bf6_t
scaled_type_convert
<
bf6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
inline
__host__
__device__
bf6_t
scaled_type_convert
<
bf6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
bf6_t
scaled_type_convert
<
bf6_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
bf6_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
template
<
>
#if CK_USE_NATIVE_MX_SUPPORT
inline
__host__
__device__
bf6x32_t
scaled_type_convert
<
bf6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
bf6x32_t
scaled_type_convert
<
bf6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F6_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
return
bf6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
...
...
include/ck/utility/type_convert.hpp
View file @
b8f4de71
...
@@ -1400,8 +1400,79 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
...
@@ -1400,8 +1400,79 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
*/
*/
inline
__host__
__device__
f6_t
f6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
inline
__host__
__device__
f6_t
f6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t
in1
{
x
};
float16_t
in2
{};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
out
.
f6_vector
=
__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32
(
in1
,
in2
,
scale
);
return
out
.
f6_array
[
0
];
#else
return
utils
::
sat_convert_to_type
<
f6_t
>
(
x
/
scale
);
return
utils
::
sat_convert_to_type
<
f6_t
>
(
x
/
scale
);
#endif
}
inline
__host__
__device__
f6x32_t
f6_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
float16_t
in1
{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
],
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
],
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
],
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
float16_t
in2
=
{
x
[
16
],
x
[
17
],
x
[
18
],
x
[
19
],
x
[
20
],
x
[
21
],
x
[
22
],
x
[
23
],
x
[
24
],
x
[
25
],
x
[
26
],
x
[
27
],
x
[
28
],
x
[
29
],
x
[
30
],
x
[
31
]};
return
__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32
(
in1
,
in2
,
scale
);
#else
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
f6_array
[
i
]
=
utils
::
sat_convert_to_type
<
f6_t
>
(
in
.
float_array
[
i
]
/
scale
);
});
return
out
.
f6_vector
;
#endif
}
}
/**
/**
...
@@ -1418,15 +1489,65 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
...
@@ -1418,15 +1489,65 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
out
.
f6_vector
=
__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32
(
in
.
float_vector
,
rng
,
scale
);
return
out
.
f6_array
[
0
];
#else
return
utils
::
sat_convert_to_type_sr
<
f6_t
>
(
x
/
scale
,
rng
);
return
utils
::
sat_convert_to_type_sr
<
f6_t
>
(
x
/
scale
,
rng
);
#endif
}
inline
__host__
__device__
f6x32_t
f6_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
float_values
{
x
};
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
float_values
.
float_array
[
0
]);
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32
(
x
,
rng
,
scale
);
#else
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
f6_array
[
i
]
=
utils
::
sat_convert_to_type_sr
<
f6_t
>
(
in
.
float_array
[
i
]
/
scale
,
rng
);
});
return
out
.
f6_vector
;
#endif
}
}
/**
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
* (f6_t).
*
*
* Depending on the CK_USE_SR_F
4
_CONVERSION flag,
* Depending on the CK_USE_SR_F
6
_CONVERSION flag,
* the conversion uses stochastic rounding
* the conversion uses stochastic rounding
* or round-to-nearest-even.
* or round-to-nearest-even.
*
*
...
@@ -1436,7 +1557,17 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
...
@@ -1436,7 +1557,17 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
template
<
>
template
<
>
inline
__host__
__device__
f6_t
type_convert
<
f6_t
,
float
>
(
float
x
)
inline
__host__
__device__
f6_t
type_convert
<
f6_t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if defined(__gfx950__)
return
f6_convert_sr
(
x
);
#else
return
f6_convert_rne
(
x
);
#endif
}
template
<
>
inline
__host__
__device__
f6x32_t
type_convert
<
f6x32_t
,
float32_t
>
(
float32_t
x
)
{
#if defined(__gfx950__)
return
f6_convert_sr
(
x
);
return
f6_convert_sr
(
x
);
#else
#else
return
f6_convert_rne
(
x
);
return
f6_convert_rne
(
x
);
...
@@ -1455,8 +1586,53 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
...
@@ -1455,8 +1586,53 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f6_t
>
(
f6_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f6_t
>
(
f6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
in
.
f6_vector
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
return
utils
::
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
#endif
}
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
f6x32_t
>
(
f6x32_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
x
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
#else
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
in
.
f6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
/**
/**
...
@@ -1471,8 +1647,79 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
...
@@ -1471,8 +1647,79 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
*/
*/
inline
__host__
__device__
bf6_t
bf6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
inline
__host__
__device__
bf6_t
bf6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t
in1
{
x
};
float16_t
in2
{};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
out
.
bf6_vector
=
__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32
(
in1
,
in2
,
scale
);
return
out
.
bf6_array
[
0
];
#else
return
utils
::
sat_convert_to_type
<
bf6_t
>
(
x
/
scale
);
return
utils
::
sat_convert_to_type
<
bf6_t
>
(
x
/
scale
);
#endif
}
inline
__host__
__device__
bf6x32_t
bf6_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
float16_t
in1
{
x
[
0
],
x
[
1
],
x
[
2
],
x
[
3
],
x
[
4
],
x
[
5
],
x
[
6
],
x
[
7
],
x
[
8
],
x
[
9
],
x
[
10
],
x
[
11
],
x
[
12
],
x
[
13
],
x
[
14
],
x
[
15
]};
float16_t
in2
=
{
x
[
16
],
x
[
17
],
x
[
18
],
x
[
19
],
x
[
20
],
x
[
21
],
x
[
22
],
x
[
23
],
x
[
24
],
x
[
25
],
x
[
26
],
x
[
27
],
x
[
28
],
x
[
29
],
x
[
30
],
x
[
31
]};
return
__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32
(
in1
,
in2
,
scale
);
#else
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
bf6_array
[
i
]
=
utils
::
sat_convert_to_type
<
bf6_t
>
(
in
.
float_array
[
i
]
/
scale
);
});
return
out
.
bf6_vector
;
#endif
}
}
/**
/**
...
@@ -1490,14 +1737,64 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
...
@@ -1490,14 +1737,64 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
out
.
bf6_vector
=
__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32
(
in
.
float_vector
,
rng
,
scale
);
return
out
.
bf6_array
[
0
];
#else
return
utils
::
sat_convert_to_type_sr
<
bf6_t
>
(
x
/
scale
,
rng
);
return
utils
::
sat_convert_to_type_sr
<
bf6_t
>
(
x
/
scale
,
rng
);
#endif
}
inline
__host__
__device__
bf6x32_t
bf6_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
float_values
{
x
};
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
float_values
.
float_array
[
0
]);
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32
(
x
,
rng
,
scale
);
#else
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
bf6_array
[
i
]
=
utils
::
sat_convert_to_type_sr
<
bf6_t
>
(
in
.
float_array
[
i
]
/
scale
,
rng
);
});
return
out
.
bf6_vector
;
#endif
}
}
/**
/**
* @brief Specializes float-to-bf6_t conversion.
* @brief Specializes float-to-bf6_t conversion.
*
*
* Uses stochastic rounding if CK_USE_SR_F
4
_CONVERSION is defined,
* Uses stochastic rounding if CK_USE_SR_F
6
_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
* otherwise uses round-to-nearest-even.
*
*
* @param x Input float value to convert.
* @param x Input float value to convert.
...
@@ -1506,7 +1803,17 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
...
@@ -1506,7 +1803,17 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
template
<
>
template
<
>
inline
__host__
__device__
bf6_t
type_convert
<
bf6_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf6_t
type_convert
<
bf6_t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
);
#else
return
bf6_convert_rne
(
x
);
#endif
}
template
<
>
inline
__host__
__device__
bf6x32_t
type_convert
<
bf6x32_t
,
float32_t
>
(
float32_t
x
)
{
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
);
return
bf6_convert_sr
(
x
);
#else
#else
return
bf6_convert_rne
(
x
);
return
bf6_convert_rne
(
x
);
...
@@ -1525,8 +1832,53 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
...
@@ -1525,8 +1832,53 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf6_t
>
(
bf6_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf6_t
>
(
bf6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
in
.
bf6_vector
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
return
utils
::
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
#endif
}
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
bf6x32_t
>
(
bf6x32_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
x
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
#else
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
in
.
bf6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment