Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5f1a24a8
Commit
5f1a24a8
authored
Nov 06, 2024
by
Rostyslav Geyyer
Browse files
Add device conversions
parent
1bca7134
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
50 deletions
+45
-50
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+45
-50
No files found.
include/ck/utility/type_convert.hpp
View file @
5f1a24a8
...
...
@@ -504,52 +504,41 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
}
// Convert an fp32 value to fp4 with rounding to nearest even.
//
// x     - the fp32 value to convert.
// scale - scaling factor, defaults to 1.0f. On the generic path the input is
//         divided by it before the saturating conversion; on gfx950 it is
//         handed to the hardware conversion builtin unchanged.
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // The builtin returns its packed result as a 32-bit word; the union lets
    // the first f4_t slot of that word be read back out.
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    // The builtin takes two scalar sources. There is only one input here, so
    // it is supplied for both and only slot 0 of the result is returned.
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
// Convert an fp32 value to fp4 with stochastic rounding.
//
// x     - the fp32 value to convert.
// scale - scaling factor, defaults to 1.0f. On the generic path the input is
//         divided by it before the saturating conversion; on gfx950 it is
//         handed to the hardware conversion builtin unchanged.
//
// The random bits come from prand_generator, fed with the address of the
// local copy of x and the value of x, using a fixed compile-time seed.
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    // The builtin returns its packed result as a 32-bit word; the union lets
    // the first f4_t slot of that word be read back out.
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    // NOTE(review): the packed stochastic-rounding builtin is declared with a
    // two-element float vector source rather than a scalar - confirm against
    // the compiler's AMDGPU builtin table. The single input is placed in both
    // lanes and only slot 0 of the result is returned.
    typedef float f32x2_t __attribute__((ext_vector_type(2)));
    const f32x2_t packed_src{x, x};
    value.bitwise =
        __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, packed_src, rng, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
...
...
@@ -568,12 +557,10 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f4_t
>
(
f4_t
data
)
{
#if defined(__gfx94__)
// float fval;
// uint32_t i32val = static_cast<uint32_t>(x);
// fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
// return fval;
#if defined(__gfx950__)
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
data
,
scale
,
0
)
.
template
AsType
<
float
>()(
Number
<
0
>
{});
#else
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
data
);
#endif
...
...
@@ -585,16 +572,24 @@ __host__ __device__ constexpr Y scaled_type_convert(e8m0_scale_t scale, X x);
// Convert a scaled fp4 value to fp32.
//
// scale - the e8m0 scale that accompanies the fp4 value.
// x     - the fp4 value to convert.
//
// Returns the fp32 result of utils::to_float<f4_t>(scale, x) on the generic
// path, or lane 0 of the hardware conversion on gfx950.
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_scale_t scale, f4_t x)
{
#if defined(__gfx950__)
    // The builtin takes its scale as a float, so the e8m0 scale is converted
    // with type_convert<float> first.
    // NOTE(review): the builtin is expected to return a two-element float
    // vector (one lane per packed fp4 value) - confirm against the compiler's
    // AMDGPU builtin table. Only lane 0 corresponds to x.
    const auto unpacked = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
    return unpacked[0];
#else
    return utils::to_float<f4_t>(scale, x);
#endif
}
// Convert an fp32 value to a scaled fp4 value.
//
// scale - the e8m0 scale to apply; it is converted to float and forwarded as
//         the scale argument of the underlying fp4 conversion.
// x     - the fp32 value to convert.
//
// CK_USE_SR_F4_CONVERSION selects stochastic rounding (f4_convert_sr);
// otherwise round-to-nearest-even (f4_convert_rne) is used.
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t scale, float x)
{
    const float scale_as_float = type_convert<float>(scale);
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, scale_as_float);
#else
    return f4_convert_rne(x, scale_as_float);
#endif
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment