gaoqiong / composable_kernel_ROCM · Commits · 2a30cfdd

Unverified commit 2a30cfdd, authored Feb 12, 2025 by arai713; committed by GitHub on Feb 12, 2025.

    Merge branch 'develop' into codegen-enable-hiprtc

Parents: 9533a172, 78195ccc

Showing 20 changed files (of 435 in the commit) with 4034 additions and 202 deletions (+4034, -202).
include/ck/utility/functional.hpp                             +1     -1
include/ck/utility/functional4.hpp                            +1     -1
include/ck/utility/integral_constant.hpp                      +6     -1
include/ck/utility/is_detected.hpp                            +5     -3
include/ck/utility/loop_scheduler.hpp                         +10    -1
include/ck/utility/magic_division.hpp                         +5     -1
include/ck/utility/math_v2.hpp                                +9     -6
include/ck/utility/mxf4_utils.hpp                             +109   -0
include/ck/utility/mxf6_utils.hpp                             +325   -0
include/ck/utility/mxf8_utils.hpp                             +570   -0
include/ck/utility/mxfp_utils.hpp                             +384   -0
include/ck/utility/random_gen.hpp                             +14    -13
include/ck/utility/scaled_type_convert.hpp                    +877   -0
include/ck/utility/sequence.hpp                               +9     -1
include/ck/utility/static_buffer.hpp                          +4     -2
include/ck/utility/statically_indexed_array_multi_index.hpp   +18    -23
include/ck/utility/tuple.hpp                                  +1     -1
include/ck/utility/tuple_helper.hpp                           +13    -2
include/ck/utility/type.hpp                                   +118   -24
include/ck/utility/type_convert.hpp                           +1555  -122
include/ck/utility/functional.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
include/ck/utility/functional4.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
...
include/ck/utility/integral_constant.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
    return integral_constant<decltype(X % Y), X % Y>{};
}

+template <bool B>
+using bool_constant = integral_constant<bool, B>;
+
+using true_type  = bool_constant<true>;
+using false_type = bool_constant<false>;

} // namespace ck
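Editor's note: the hunk above gives CK its own bool_constant/true_type/false_type aliases so device code does not have to pull in <type_traits> (which is unavailable under hipRTC). A minimal standalone sketch of the same alias pattern, with local stand-in names rather than the ck:: ones, to show the aliases behave like their std:: counterparts:

#include <cstdio>

// Local stand-ins mirroring the pattern added above (names are illustrative).
template <typename T, T V>
struct integral_constant { static constexpr T value = V; };

template <bool B>
using bool_constant = integral_constant<bool, B>;

using true_type  = bool_constant<true>;
using false_type = bool_constant<false>;

int main()
{
    static_assert(true_type::value && !false_type::value, "behaves like std::bool_constant");
    std::printf("%d %d\n", true_type::value, false_type::value);
}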
include/ck/utility/is_detected.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

+#include "ck/utility/integral_constant.hpp"

namespace ck {
namespace detail {

template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
-    using value_t = ck::false_type;
+    using value_t = integral_constant<bool, false>;
    using type    = Default;
};

template <class Default, template <class...> class Op, class... Args>
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
-    using value_t = ck::true_type;
+    using value_t = integral_constant<bool, true>;
    using type    = Op<Args...>;
};

} // namespace detail
...
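Editor's note: the detector primary/partial-specialization pair above is the standard C++17 detection idiom. A self-contained sketch of how it is typically consumed (the is_detected_v spelling and Nonesuch default here are illustrative, not CK's exported names):

#include <cstdio>

template <class...>
using void_t = void;

template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    static constexpr bool value = false; // Op<Args...> is ill-formed
    using type                  = Default;
};

template <class Default, template <class...> class Op, class... Args>
struct detector<Default, void_t<Op<Args...>>, Op, Args...>
{
    static constexpr bool value = true; // Op<Args...> is well-formed
    using type                  = Op<Args...>;
};

struct Nonesuch {};

template <template <class...> class Op, class... Args>
constexpr bool is_detected_v = detector<Nonesuch, void, Op, Args...>::value;

// probe: does T expose a member typedef `type`?
template <class T>
using member_type_t = typename T::type;

struct HasType { using type = int; };
struct NoType {};

int main()
{
    std::printf("%d %d\n",
                is_detected_v<member_type_t, HasType>,  // 1
                is_detected_v<member_type_t, NoType>);  // 0
}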
include/ck/utility/loop_scheduler.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

+#ifdef __HIPCC_RTC__
+#define CK_CODE_GEN_RTC
+#endif

#ifndef __HIPCC_RTC__
+#ifndef CK_CODE_GEN_RTC
#include <ostream>
+#endif
#endif

#include "ck/utility/common_header.hpp"
...

@@ -28,6 +35,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck

#ifndef __HIPCC_RTC__
+#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
    switch(s)
...

@@ -39,3 +47,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
    return os;
}
+#endif
#endif
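Editor's note: the change above derives CK_CODE_GEN_RTC from hipRTC's predefined __HIPCC_RTC__ and double-guards host-only code. A standalone sketch of the same guard pattern (LoopScheduler here is a local stand-in, not ck::LoopScheduler):

#ifdef __HIPCC_RTC__
#define CK_CODE_GEN_RTC
#endif

#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
#include <iostream> // host-only: hipRTC ships no C++ standard library headers
#endif
#endif

enum class LoopScheduler { Default, Interwave }; // stand-in enum for the sketch

#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, LoopScheduler s)
{
    return os << (s == LoopScheduler::Default ? "Default" : "Interwave");
}

int main() { std::cout << LoopScheduler::Default << '\n'; }
#endif
#endif

Under an RTC compile both guards evaluate false, so the stream operator and the <iostream> include vanish; a normal host compile keeps them.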
include/ck/utility/magic_division.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...

@@ -10,6 +10,10 @@
#include "type.hpp"
#include "tuple.hpp"

+#ifdef CK_CODE_GEN_RTC
+#define INT32_MAX 2147483647
+#endif

namespace ck {

// magic number division
...
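Editor's note: the INT32_MAX define is needed because hipRTC provides no <cstdint>/<climits>, yet the magic-division setup compares against it. For readers unfamiliar with the technique the header implements, here is a minimal, generic magic-number division sketch (this is not CK's exact algorithm or API; it relies on the compiler's 128-bit integer extension and assumes divisors in [1, 2^31]):

#include <cstdint>
#include <cstdio>

// Replace x / d by a multiply and shift with a precomputed "magic" pair.
struct Magic { uint64_t m; unsigned s; };

Magic make_magic(uint32_t d)
{
    unsigned s = 0;
    while((1ull << s) < d) ++s;                    // s = ceil(log2(d))
    uint64_t m = ((1ull << (32 + s)) + d - 1) / d; // m = ceil(2^(32+s) / d)
    return {m, s};
}

uint32_t magic_div(uint32_t x, Magic mg)
{
    // floor(x * m / 2^(32+s)) == x / d for all 32-bit x (round-up method)
    return static_cast<uint32_t>((static_cast<unsigned __int128>(x) * mg.m) >> (32 + mg.s));
}

int main()
{
    Magic mg = make_magic(7);
    for(uint32_t x : {0u, 6u, 7u, 100u, 4294967295u})
        std::printf("%u / 7 = %u (ref %u)\n", x, magic_div(x, mg), x / 7u);
}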
include/ck/utility/math_v2.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

+#ifdef __HIPCC_RTC__
+#define CK_CODE_GEN_RTC
+#endif

#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
...

@@ -20,7 +24,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);

-#ifndef __HIPCC_RTC__
+#ifndef CK_CODE_GEN_RTC
// math functions for the host, some are implemented by calling C++ std functions

static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); };
...

@@ -81,7 +85,7 @@ static inline __host__ bool isnan(half_t x)
    return (xx & 0x7FFF) > 0x7C00;
};

-static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
...

@@ -461,7 +465,6 @@ inline __host__ double expm1<double>(double x)
    return std::expm1(x);
}
#endif

// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
...

@@ -533,7 +536,7 @@ static inline __device__ bool isnan(half_t x)
    return (xx & 0x7FFF) > 0x7C00;
};

-static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };

static inline __device__ half_t sqrt(half_t x)
{
...

@@ -613,7 +616,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template <>
inline __device__ half_t neg<half_t>(half_t x)
{
-    return __hneg(x);
+    return __hneg(static_cast<__half>(x));
};

template <typename T>
...
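Editor's note: the old isnan(f8_t) merely tested the sign bit (0x80), which flags every negative fp8 value as NaN; the replacement delegates to ck::fp8_is_nan. A standalone sketch of what a correct OCP-FP8 NaN test looks like (assumed encodings, illustrative helpers — not CK's fp8_is_nan):

#include <cstdint>
#include <cstdio>

// OCP e4m3 (no infinities): NaN is exponent+mantissa all ones -> 0x7F / 0xFF.
bool e4m3_is_nan(uint8_t x) { return (x & 0x7F) == 0x7F; }

// OCP e5m2: exponent all ones with nonzero mantissa is NaN (0x7C is +inf).
bool e5m2_is_nan(uint8_t x) { return (x & 0x7C) == 0x7C && (x & 0x03) != 0; }

int main()
{
    std::printf("%d %d %d\n",
                e4m3_is_nan(0x7F),  // 1: canonical e4m3 NaN
                e4m3_is_nan(0x80),  // 0: negative zero (old sign-bit test said NaN)
                e5m2_is_nan(0x7E)); // 1: e5m2 NaN payload
}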
include/ck/utility/mxf4_utils.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
                                             f4_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
}

// no infinity representation in ocp_e2m1_mxfp4; will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f4_t const data [[maybe_unused]])
{
    // no inf representation for ocp_e2m1_mxfp4
    return false;
}

template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;

    return result == 0b0;
}

template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f4_t>(scale, data))
        return 0.0f;

    f4_t prepared_data = data & 0b00001111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f4_t>(prepared_data, scale_exp);
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type<f4_t>(value);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type_sr<f4_t>(value, seed);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
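Editor's note: to_float<f4_t> above decodes an e2m1 nibble and applies the block's e8m0 scale, i.e. value = (-1)^s * 2^(e-1) * 1.m * 2^(scale-127). A standalone sketch of that decode (layout assumptions stated in the comments; this mirrors the math, not CK's helpers):

#include <cmath>
#include <cstdint>
#include <cstdio>

// e2m1 nibble layout assumed: S EE M (sign, 2 exponent bits with bias 1,
// 1 mantissa bit); the e8m0 scale byte is a pure biased exponent, bias 127.
float decode_mxfp4(uint8_t scale_e8m0, uint8_t nibble)
{
    if(scale_e8m0 == 0xFF) return NAN; // e8m0 NaN marks the whole block NaN

    int sign = (nibble >> 3) & 1;
    int exp  = (nibble >> 1) & 0x3;
    int mant = nibble & 1;

    float data = (exp == 0)
                     ? mant * 0.5f                               // subnormal: 2^(1-1) * 0.m
                     : std::ldexp(1.0f + 0.5f * mant, exp - 1);  // normal: 2^(e-1) * 1.m
    if(sign) data = -data;

    return std::ldexp(data, scale_e8m0 - 127); // apply 2^(scale - 127)
}

int main()
{
    // nibble 0b0111 encodes +6.0 in e2m1; with scale 126 (i.e. 2^-1) the value is 3.0
    std::printf("%g\n", decode_mxfp4(126, 0b0111));
}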
include/ck/utility/mxf6_utils.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

/**
 * @brief Checks if an f6_t value is NaN based on the provided scale.
 *
 * For f6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param dataBytes The f6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
                                             f6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if a bf6_t value is NaN based on the provided scale.
 *
 * For bf6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param dataBytes The bf6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
                                              bf6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if an f6_t value is infinite.
 *
 * Because f6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return Always false, as infinity is not represented in f6_t.
 */
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f6_t const data [[maybe_unused]])
{
    // no inf representation for fp6
    return false;
}

/**
 * @brief Checks if a bf6_t value is infinite.
 *
 * Because bf6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return Always false, as infinity is not represented in bf6_t.
 */
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                              bf6_t const data [[maybe_unused]])
{
    // no inf representation for bf6
    return false;
}

/**
 * @brief Checks whether an f6_t value is zero.
 *
 * If the specified f6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Checks whether a bf6_t value is zero.
 *
 * If the specified bf6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the f6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f6_t>(scale, data))
        return 0.0f;

    f6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the bf6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<bf6_t>(scale, data))
        return 0.0f;

    bf6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<bf6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a float to f6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type<f6_t>(value);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type<bf6_t>(value);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to f6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Random bits used for stochastic rounding.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type_sr<f6_t>(value, seed);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Random bits used for stochastic rounding.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type_sr<bf6_t>(value, seed);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
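Editor's note: every sat_convert_to_type overload above follows the same saturation policy: NaN and over-range inputs clamp to the signed max-normal pattern (fp6 has no infinities), and results below the smallest subnormal flush to signed zero. A standalone float-level sketch of that policy for f6_t (e2m3), using limits derived from the format (max normal 7.5, min subnormal 0.125 — my derivation, stated as an assumption):

#include <cmath>
#include <cstdio>

// Models only the clamping/flushing decisions, not the bit-level encoding.
float saturate_f6(float x)
{
    const float kMax = 7.5f, kMinSub = 0.125f;
    if(std::isnan(x)) return std::signbit(x) ? -kMax : kMax; // NaN -> signed max normal
    if(std::fabs(x) > kMax) return std::copysign(kMax, x);   // covers +/-inf too
    if(std::fabs(x) < kMinSub) return std::copysign(0.0f, x); // flush to signed zero
    return x; // in-range values continue to the rounding path (not modeled here)
}

int main()
{
    std::printf("%g %g %g\n",
                saturate_f6(100.0f),    // 7.5
                saturate_f6(-INFINITY), // -7.5
                saturate_f6(0.01f));    // 0
}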
include/ck/utility/mxf8_utils.hpp (new file, mode 100644)

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif

namespace ck {
namespace fp8_impl {

#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t
cast_to_f8_from_f32_scaled(float v, unsigned int rng = 0, float scale = 1.0f)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
    } val;

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    // unsigned int ival = 0;
    val.fval = v;

    if constexpr(stochastic_rounding)
    {
        ret.ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                       ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
                       : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
        i8data = ret.v4i8[0];
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(
                /*old_vdst*/ ret.v2i16, val.fval, val.fval, scale, /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(
                /*old_vdst*/ ret.v2i16, val.fval, val.fval, scale, /*dst_lo_hi_sel*/ false);
        }
        i8data = ret.v4i8[0];
    }

    return i8data;
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t
cast_to_f8_from_f32_scaled(float2_t v, unsigned int rng = 0, float scale = 1.0f)
{
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        return ret.v2f8x2(Number<0>{});
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH

#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f,
                                                                         float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                           float scale)
{
    __is_interpret_supported(interp);
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f,
                                                                         float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}

/**
 * \brief convert two float to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                           float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl

// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);

// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x,
                                                                            float scale)
{
    return f8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    return bf8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

} // namespace ck
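Editor's note: a short usage sketch of the entry points defined above (assumes the CK headers are on the include path; the fast hardware path additionally requires a gfx950 device build, otherwise the portable cast_to_f8 fallback divides by the scale in software). Values are illustrative only:

// Usage sketch for mxf8_convert_rne / mxf8_convert_sr (not a complete kernel).
#include "ck/utility/mxf8_utils.hpp"

__device__ void quantize_example()
{
    float x     = 3.14159f;
    float scale = 2.0f; // block scale, e.g. decoded from an e8m0 byte

    // round-to-nearest-even: encodes x / scale as OCP e4m3
    ck::f8_ocp_t q_rne = ck::mxf8_convert_rne<ck::f8_ocp_t>(x, scale);

    // stochastic rounding: random bits come from the internal prand_generator seed
    ck::f8_ocp_t q_sr = ck::mxf8_convert_sr<ck::f8_ocp_t>(x, scale);

    (void)q_rne;
    (void)q_sr;
}

The x16/x32 overloads simply reinterpret the wide vector as an array of two-element packs (via the unions above) and loop with ck::static_for, so the pack-wise hardware instruction is the only conversion primitive actually used.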
include/ck/utility/mxfp_utils.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck::utils {

union cvt
{
    float value_float;
    uint32_t value_bitwise;
};

template <typename DTYPE>
inline bool getDataHasInf()
{
    return DTYPE::dataInfo.hasInf;
}

template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
    x >>= NumericUtils<T>::mant;
    x &= ((1 << NumericUtils<T>::exp) - 1);
    return static_cast<int>(x);
}

template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    return get_exponent_value<T>(x) == 0;
}

template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;

    for(uint i = 0; i < NumericUtils<T>::mant; i++)
    {
        mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
        x >>= 1;
    }

    return mantissa;
}

template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
    return NumericUtils<T>::has_inf;
}

template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    float d_sign =
        std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));

    float d_exp;
    if(is_subnormal<T>(data))
        d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
    else
        d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));

    float d_mant = get_mantissa_value<T>(data);

    float data_value  = d_sign * d_exp * d_mant;
    float scale_value = std::pow(
        2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));

    return data_value * scale_value;
}

template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ T sat_convert_to_type(float value);

template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);

template <typename T>
inline T convert_to_type(float value)
{
    using bitwise_type = typename NumericUtils<T>::bitwise_type;

    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();
        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float        = max_value;
        uint32_t max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        bitwise_type sign =
            t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        bitwise_type exp =
            ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
            (NumericUtils<float>::bias - NumericUtils<T>::bias);
        bitwise_type mantissa =
            max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
        t.value_bitwise  = prev_bit;
        float prev_val   = t.value_float;
        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                   (exp << NumericUtils<T>::mant) | mantissa;
        }
        else
        {
            if(!get_data_has_inf<T>())
            {
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            }
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                       (exp << NumericUtils<T>::mant);
            }
        }
    }

    const int mfmt = NumericUtils<float>::mant;
    uint32_t x;
    x = bit_cast<uint32_t>(value);

    uint32_t head, mantissa;
    int32_t exponent, bias;
    uint32_t sign;

    head     = x & NumericUtils<float>::head_mask;
    mantissa = x & NumericUtils<float>::mant_mask;
    exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    sign     = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    bias     = NumericUtils<float>::bias;

    if(x == 0)
    {
        return 0b0;
    }

    const int mini_bias                  = NumericUtils<T>::bias;
    const int mini_denormal_act_exponent = 1 - mini_bias;

    int act_exponent, out_exponent, exponent_diff;
    bool is_subnorm = false;

    if(exponent == 0)
    {
        act_exponent  = exponent - bias + 1;
        exponent_diff = mini_denormal_act_exponent - act_exponent;
        is_subnorm    = true;
    }
    else
    {
        act_exponent = exponent - bias;
        if(act_exponent <= mini_denormal_act_exponent)
        {
            exponent_diff = mini_denormal_act_exponent - act_exponent;
            is_subnorm    = true;
        }
        else
        {
            exponent_diff = 0;
        }
        mantissa += (1UL << mfmt);
    }

    auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
    shift_amount      = (shift_amount >= 64) ? 63 : shift_amount;
    bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));

    float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
    if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
    {
        // closer to 0
        if(std::abs(value) <= std::abs(min_subnorm - value))
            return 0;
        else
            return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
    }

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;

    bool implicit_one = mantissa & (1 << mfmt);
    out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);

    uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
    bool odd           = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
    mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;

    if(out_exponent == 0)
    {
        if((1UL << mfmt) & mantissa)
        {
            out_exponent = 1;
        }
    }
    else
    {
        if((1UL << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
        }
    }

    mantissa >>= (mfmt - NumericUtils<T>::mant);

    if(out_exponent == 0 && mantissa == 0)
    {
        return 0;
    }

    mantissa &= (1UL << NumericUtils<T>::mant) - 1;
    return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
           (out_exponent << NumericUtils<T>::mant) | mantissa;
}

template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();
        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float    = max_value;
        uint max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        T exp  = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
                (NumericUtils<float>::bias - NumericUtils<T>::bias);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
        t.value_bitwise  = prev_bit;
        float prev_val   = t.value_float;
        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            double d_max_value  = static_cast<double>(max_value);
            double d_actual_max = static_cast<double>(actual_max);
            double d_value      = static_cast<double>(value);
            double d_is         = std::abs(d_max_value - d_actual_max);
            double d_seed       = static_cast<double>(seed);
            double d_prob       = 1.0f - (std::abs(d_value - d_max_value) / d_is);
            // prob to round down
            double thresh = UINT_MAX * d_prob;
            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
                return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
                                 : NumericUtils<f4_t>::data_max_negative_normal_mask;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
        else
        {
            if(!get_data_has_inf<T>())
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
    }

    uint32_t f32 = bit_cast<uint32_t>(value);

    auto f32_mant = f32 & NumericUtils<float>::mant_mask;
    auto head     = f32 & NumericUtils<float>::head_mask;
    auto f32_exp  = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    auto sign     = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);

    f32_exp      = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
    int32_t exp  = f32_exp;
    auto mant    = f32_mant;
    bool subnorm = false;

    if(f32 == 0)
        return 0b0;

    if(exp >= NumericUtils<T>::unbiased_exp_min)
    {
        mant = f32_mant;
    }
    // if the exponent bit is 8, then the subnormal is exactly the same as f32
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
    {
        subnorm = true;

        auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
        if(diff >= 32)
        {
            mant     = 0;
            f32_mant = 0;
        }
        else
        {
            f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
            f32_mant >>= diff;
        }
        exp  = 0;
        mant = f32_mant;
    }

    uint32_t sr_shift = NumericUtils<T>::sr_shift;

    // For stochastic-rounding we add the aligned random value to the
    // mantissa and then truncate (RTZ).
    mant += seed >> sr_shift;

    // Increment exponent when mantissa overflows due to rounding
    if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
        ++exp;

    mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
    mant &= ((1 << NumericUtils<T>::mant) - 1);

    auto biased_exp = static_cast<uint32_t>(exp);
    if(!subnorm)
        biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
    biased_exp &= ((1 << NumericUtils<T>::exp) - 1);

    auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
    return val;
}

} // namespace ck::utils
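Editor's note: the heart of convert_to_type_sr above is "add aligned random bits to the mantissa, then truncate". A standalone sketch of just that step (the kept/dropped bit counts here are chosen for an e2m3-like target and stand in for NumericUtils<T>::sr_shift; exponent carry-out, handled by the real code, is not modeled):

#include <cstdint>
#include <cstdio>

// Round a 23-bit fp32 mantissa down to 3 kept bits with stochastic rounding:
// shift a 32-bit random value into the 20 dropped bit positions, add, truncate.
uint32_t sr_round_mantissa(uint32_t mant23, uint32_t rand32)
{
    const unsigned kept = 3, dropped = 23 - kept;
    mant23 += rand32 >> (32 - dropped); // aligned randomness in the dropped bits
    return mant23 >> dropped;           // truncate toward zero
}

int main()
{
    // a value exactly halfway between representables rounds up ~50% of the time
    uint32_t half = 1u << 19; // 0.5 ulp of the 3-bit result
    std::printf("%u %u\n",
                sr_round_mantissa(half, 0),            // 0: random bits too small
                sr_round_mantissa(half, 0xFFFFFFFFu)); // 1: random bits push it over
}

Averaged over uniform random bits, the rounded result is unbiased: the probability of rounding up equals the discarded fraction, which is why SR is attractive for accumulating many low-precision quantization errors.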
include/ck/utility/random_gen.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

+#include <ck/utility/ignore.hpp>
#include "ck/ck.hpp"

+#ifdef CK_CODE_GEN_RTC
+using uint8_t  = unsigned char;
+using uint16_t = unsigned short;
+using uint32_t = unsigned int;
+#endif

namespace ck {

// Pseudo random number generator
// version for fp32
-template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<float, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
...

@@ -23,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}

// version for fp16
-template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<half_t, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...

@@ -40,18 +47,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template <typename T,
          uint32_t seed_t,
-          ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
+          ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
-#ifdef __HIPCC_RTC__
-    static_cast<void>(id);
-    static_cast<void>(val);
-    static_cast<void>(seed);
-#else
-    std::ignore = id;
-    std::ignore = val;
-    std::ignore = seed;
-#endif
+    ck::ignore = id;
+    ck::ignore = val;
+    ck::ignore = seed;
    return 0;
}
...
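Editor's note: the bodies of prand_generator are elided above; in general terms it hashes the element index, the value's bit pattern, and a compile-time seed into 32 random bits for stochastic rounding. A generic standalone sketch in that spirit (explicitly not CK's mixing function; the finalizer constants are from the well-known lowbias32 hash, and the seed matches the 1254739 constant used by mxf8_utils.hpp):

#include <cstdint>
#include <cstdio>
#include <cstring>

uint32_t prand_sketch(uint32_t id, float val, uint32_t seed)
{
    uint32_t x;
    std::memcpy(&x, &val, sizeof(x));         // fp32 bit pattern (no strict-aliasing UB)
    uint32_t h = id * 2654435761u ^ x ^ seed; // combine index, bits, and seed
    h ^= h >> 16; h *= 0x7feb352du;           // avalanche so nearby inputs
    h ^= h >> 15; h *= 0x846ca68bu;           // produce unrelated outputs
    h ^= h >> 16;
    return h;
}

int main() { std::printf("0x%08x\n", prand_sketch(42, 1.5f, 1254739u)); }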
include/ck/utility/scaled_type_convert.hpp
0 → 100644
View file @
2a30cfdd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace
ck
{
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
#if CK_USE_OCP_FP8
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#else
__host__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#endif
// convert f8_ocp_t to fp32
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f8_ocp_t
>
(
e8m0_bexp_t
scale
,
f8_ocp_t
x
)
#else
inline
__host__
float
scaled_type_convert
<
float
,
f8_ocp_t
>
(
e8m0_bexp_t
scale
,
f8_ocp_t
x
)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32_from_f8_scaled
<
f8_ocp_t
::
default_interpret
>
(
type_convert
<
float
>
(
scale
),
x
.
data
);
#else
return
type_convert
<
float
>
(
scale
)
*
type_convert
<
float
>
(
x
);
#endif
}
// convert bf8_ocp_t to fp32
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float
scaled_type_convert
<
float
,
bf8_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8_ocp_t
x
)
#else
inline
__host__
float
scaled_type_convert
<
float
,
bf8_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8_ocp_t
x
)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32_from_f8_scaled
<
bf8_ocp_t
::
default_interpret
>
(
type_convert
<
float
>
(
scale
),
x
.
data
);
#else
return
type_convert
<
float
>
(
scale
)
*
type_convert
<
float
>
(
x
);
#endif
}
// convert 2 x f8_ocp_t to 2 x fp32
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f8x2_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x2_ocp_t
x
)
#else
inline
__host__
float2_t
scaled_type_convert
<
float2_t
,
f8x2_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x2_ocp_t
x
)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2_scaled
<
f8_ocp_t
::
default_interpret
>
(
type_convert
<
float
>
(
scale
),
x
.
AsType
<
fp8_impl
::
fp8x2_storage_t
>
()[
Number
<
0
>
{}]);
#else
return
float2_t
{
scaled_type_convert
<
float
>
(
scale
,
x
.
AsType
<
f8_ocp_t
>
()[
Number
<
0
>
{}]),
scaled_type_convert
<
float
>
(
scale
,
x
.
AsType
<
f8_ocp_t
>
()[
Number
<
1
>
{}])};
#endif
}
// convert 2 x bf8_ocp_t to 2 x fp32
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
bf8x2_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x2_ocp_t
x
)
#else
inline
__host__
float2_t
scaled_type_convert
<
float2_t
,
bf8x2_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x2_ocp_t
x
)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2_scaled
<
bf8_ocp_t
::
default_interpret
>
(
type_convert
<
float
>
(
scale
),
x
.
AsType
<
fp8_impl
::
fp8x2_storage_t
>
()[
Number
<
0
>
{}]);
#else
return
float2_t
{
scaled_type_convert
<
float
>
(
scale
,
x
.
AsType
<
bf8_ocp_t
>
()[
Number
<
0
>
{}]),
scaled_type_convert
<
float
>
(
scale
,
x
.
AsType
<
bf8_ocp_t
>
()[
Number
<
1
>
{}])};
#endif
}
// convert 16 x f8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float16_t
scaled_type_convert
<
float16_t
,
f8x16_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x16_ocp_t
x
)
#else
inline
__host__
float16_t
scaled_type_convert
<
float16_t
,
f8x16_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x16_ocp_t
x
)
#endif
{
union
{
f8x16_ocp_t
f8_1x16
;
f8x2_ocp_t
f8_2x8
[
8
];
}
in
{
x
};
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_2x8
[
i
]
=
scaled_type_convert
<
float2_t
,
f8x2_ocp_t
>
(
scale
,
in
.
f8_2x8
[
i
]);
});
return
out
.
float_1x16
;
}
// convert 16 x bf8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float16_t
scaled_type_convert
<
float16_t
,
bf8x16_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x16_ocp_t
x
)
#else
inline
__host__
float16_t
scaled_type_convert
<
float16_t
,
bf8x16_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x16_ocp_t
x
)
#endif
{
union
{
bf8x16_ocp_t
bf8_1x16
;
bf8x2_ocp_t
bf8_2x8
[
8
];
}
in
{
x
};
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_2x8
[
i
]
=
scaled_type_convert
<
float2_t
,
bf8x2_ocp_t
>
(
scale
,
in
.
bf8_2x8
[
i
]);
});
return
out
.
float_1x16
;
}
// convert 32 x f8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f8x32_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x32_ocp_t
x
)
#else
inline
__host__
float32_t
scaled_type_convert
<
float32_t
,
f8x32_ocp_t
>
(
e8m0_bexp_t
scale
,
f8x32_ocp_t
x
)
#endif
{
union
{
f8x32_ocp_t
f8_1x32
;
f8x16_ocp_t
f8_16x2
[
2
];
}
in
{
x
};
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_16x2
[
i
]
=
scaled_type_convert
<
float16_t
,
f8x16_ocp_t
>
(
scale
,
in
.
f8_16x2
[
i
]);
});
return
out
.
float_1x32
;
}
// convert 32 x bf8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
bf8x32_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x32_ocp_t
x
)
#else
inline
__host__
float32_t
scaled_type_convert
<
float32_t
,
bf8x32_ocp_t
>
(
e8m0_bexp_t
scale
,
bf8x32_ocp_t
x
)
#endif
{
union
{
bf8x32_ocp_t
bf8_1x32
;
bf8x16_ocp_t
bf8_16x2
[
2
];
}
in
{
x
};
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
out
.
float_16x2
[
i
]
=
scaled_type_convert
<
float16_t
,
bf8x16_ocp_t
>
(
scale
,
in
.
bf8_16x2
[
i
]);
});
return
out
.
float_1x32
;
}
// convert fp32 to fp8
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
mxf8_convert_rne
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
                                                                           float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                  float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                       float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                        float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                          float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                        float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                          float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
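// Taken together, these specializations give a single entry point for scale-aware
// fp8/bf8 conversion. A minimal device-side sketch of a round trip follows; the
// function name is illustrative, and it assumes (not stated in this file) that an
// e8m0_bexp_t built from the biased exponent 127 decodes to a scale of 1.0f.
//
// __device__ void example_scaled_f8_roundtrip()
// {
//     using namespace ck;
//     e8m0_bexp_t scale{127}; // assumed: biased exponent 127 decodes to 1.0f
//     float2_t x{0.5f, -1.25f};
//     // fp32 -> fp8 (RNE or SR, depending on CK_USE_SR_F8_CONVERSION)
//     f8x2_ocp_t q = scaled_type_convert<f8x2_ocp_t, float2_t>(scale, x);
//     // fp8 -> fp32, applying the same decoded scale on the way back
//     float2_t y = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, q);
//     (void)y;
// }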
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float_values.float2_array =
        __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);
    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale,
                                                                          f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};
    value.f4x2_array[0] = x;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
    float2_t ret{utils::to_float<f4_t>(
                     scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
                 utils::to_float<f4_t>(
                     scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
                                                                             f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    // Convert one packed pair of fp4 values per iteration.
    ck::static_for<0, 16, 1>{}([&](auto i) {
        bitwise_value.f4x2_array[0] = value.fp4x2[i.value];
        op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret[2 * i.value]     = op[0];
        ret[2 * i.value + 1] = op[1];
    });
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // Unpack two fp4 values per packed byte; the output index must keep advancing
    // across all 32 lanes rather than wrapping back to 0 every eight elements.
    ck::static_for<0, 16, 1>{}([&](auto i) {
        float_values.float_array[2 * i.value] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i.value].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * i.value + 1] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i.value].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    });
    return float_values.float32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
                                                                        float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param x The f6_t value to be converted.
* @return The converted 32-bit float representation of the input.
*/
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
                                                                             f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
    return out.float_vector;
#endif
}
/**
* @brief Converts a 6-bit floating-point value (bf6_t) to a 32-bit float,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param x The bf6_t value to be converted.
* @return The converted 32-bit float representation of the input.
*/
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
* @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
* @return The converted vector of 32 float representation of the input.
*/
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
                                                                              bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
    return out.float_vector;
#endif
}
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the specified
* scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param x The float value to convert.
* @return The converted 6-bit floating-point value (f6_t).
*/
template <>
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
* scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param x The float value to convert.
* @return The converted 6-bit floating-point value (bf6_t).
*/
template <>
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                             float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT
} // namespace ck
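// The MX-type specializations above all share one calling convention. A hedged usage
// sketch (the function name and parameter values are illustrative, not taken from this
// file, and the f6 path is only available when CK_USE_NATIVE_MX_SUPPORT is active):
//
// __device__ float example_scaled_f6_roundtrip(ck::e8m0_bexp_t scale, float in)
// {
//     // fp32 -> f6_t picks f6_convert_sr or f6_convert_rne per CK_USE_SR_F6_CONVERSION
//     ck::f6_t q = ck::scaled_type_convert<ck::f6_t, float>(scale, in);
//     // f6_t -> fp32 applies the same scale during decoding
//     return ck::scaled_type_convert<float, ck::f6_t>(scale, q);
// }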
include/ck/utility/sequence.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#endif
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
...
...
@@ -903,6 +909,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
...
...
@@ -914,3 +921,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
    return os;
}
#endif
#endif
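The stream operator kept behind these guards enables host-side logging of compile-time sequences; a small sketch (host-only, since <ostream> is excluded under the hipRTC/codegen-RTC builds):

#include <iostream>
#include "ck/utility/sequence.hpp"

int main()
{
    // prints the elements of the Sequence via the operator<< above
    std::cout << ck::Sequence<0, 1, 2, 3>{} << std::endl;
}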
include/ck/utility/static_buffer.hpp
View file @ 2a30cfdd
...
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
-    template <typename X,
-              index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+    template <typename X,
+              index_t I,
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
    __host__ __device__ constexpr auto GetAsType(Number<I> i) const
    {
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...
...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
-    template <typename X,
-              index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+    template <typename X,
+              index_t I,
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
    {
        constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...
...
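The widened constraint now also admits buffers whose element type S is not a native type (for example packed or emulated types), rather than only X with a matching scalar type. A standalone mock of the same SFINAE shape, with `accepts_like_get_as_type` as a hypothetical name introduced purely for illustration:

template <typename S,
          typename X,
          typename ck::enable_if<ck::has_same_scalar_type<S, X>::value || !ck::is_native_type<S>(),
                                 bool>::type = false>
__host__ __device__ constexpr void accepts_like_get_as_type(X)
{
    // substitutes exactly when GetAsType/SetAsType above would
}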
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
...
...
@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
-template <typename... Ys,
-          typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
...
@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
    return y;
}
-template <typename... Ys,
-          typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{
    static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...
...
@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
    return y;
}
-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
...
@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
    return r;
}
-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
...
@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
    return r;
}
-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...
...
@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
// MultiIndex = scalar * MultiIndex
-template <typename... Xs,
-          typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{
    constexpr index_t NSize = sizeof...(Xs);
...
...
@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
// MultiIndex = MultiIndex * scalar
-template <typename... Xs,
-          typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{
    return a * x;
...
...
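These operators make MultiIndex arithmetic read like scalar arithmetic; a brief sketch, assuming `make_multi_index` from this header family:

__host__ __device__ void example_multi_index_arithmetic()
{
    auto idx    = ck::make_multi_index(1, 2, 3); // Tuple-based MultiIndex<3>
    auto sum    = idx + idx;                     // element-wise Tuple operator+
    auto scaled = 4 * idx;                       // scalar * MultiIndex via ck::is_integral
    (void)sum;
    (void)scaled;
}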
include/ck/utility/tuple.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck/utility/tuple_helper.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#include "functional4.hpp"
#include "tuple.hpp"
#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp"
#endif
namespace ck {
...
...
@@ -158,13 +164,18 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
}
#ifndef __HIPCC_RTC__
#ifndef CK_CODE_GEN_RTC
template <typename T>
-using is_tuple = decltype(std::declval<T&>().IsTuple());
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
#endif
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
#ifndef CK_CODE_GEN_RTC
    return (is_detected<is_tuple, Ts>::value || ...);
#endif
}
#endif
...
...
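With is_tuple now built on ck::declval, IsNestedTuple keeps working in both host and RTC translation units; a host-side sketch (assumes the Tuple types are constexpr default-constructible, which is the usual CK pattern):

#include "ck/utility/tuple_helper.hpp"

void example_is_nested_tuple()
{
    using Flat   = ck::Tuple<int, float>;
    using Nested = ck::Tuple<ck::Tuple<int>, float>;
    static_assert(!ck::IsNestedTuple(Flat{}), "no tuple-typed members");
    static_assert(ck::IsNestedTuple(Nested{}), "tuple member detected via is_detected");
}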
include/ck/utility/type.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef _HIPCC_RTC_
#define CK_CODE_GEN_RTC
#endif
#include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {

-#ifdef __HIPCC_RTC__
-template <bool B>
-using bool_constant = integral_constant<bool, B>;
-using true_type    = bool_constant<true>;
-using false_type   = bool_constant<false>;
+#ifdef CK_CODE_GEN_RTC
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
...
...
@@ -75,7 +73,6 @@ struct remove_reference<T&&>
{
    typedef T type;
};

template <class T>
struct remove_pointer
{
...
...
@@ -107,7 +104,6 @@ constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
    return static_cast<T&&>(t_);
}

template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
...
...
@@ -115,17 +111,17 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
}
template <class T>
-struct is_const : false_type
+struct is_const : public integral_constant<bool, false>
{
};

template <class T>
-struct is_const<const T> : true_type
+struct is_const<const T> : public integral_constant<bool, true>
{
};
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;

-template <class T>
+template <typename T>
inline constexpr bool is_reference_v = is_reference<T>::value;

template <class T>
...
...
@@ -140,15 +136,13 @@ struct remove_const<const T>
};

template <class T>
using remove_const_t = typename remove_const<T>::type;

template <class T>
inline constexpr bool is_class_v = is_class<T>::value;

template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;

-template <class...>
-using void_t = void;

// template <typename T>
// T&& declval() noexcept;
template <class T, class U = T&&>
U private_declval(int);
...
...
@@ -159,12 +153,12 @@ T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));

template <class...>
using void_t = void;

#else
#include <utility>
#include <type_traits>

using std::declval;
-using std::false_type;
using std::forward;
using std::is_base_of;
using std::is_class;
...
...
@@ -180,9 +174,8 @@ using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
-using std::true_type;
using std::void_t;

#endif
#endif
template <typename X, typename Y>
...
...
@@ -195,15 +188,117 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};

template <typename X>
struct is_floating_point : public integral_constant<bool, false>
{
};
template <>
struct is_floating_point<float> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<double> : public integral_constant<bool, true>
{
};
template <>
struct is_floating_point<long double> : public integral_constant<bool, true>
{
};

template <typename X>
struct is_integral : public integral_constant<bool, false>
{
};
template <>
struct is_integral<int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned int> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned short> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned long long> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<signed char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<unsigned char> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<wchar_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char16_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<char32_t> : public integral_constant<bool, true>
{
};
template <>
struct is_integral<bool> : public integral_constant<bool, true>
{
};

template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;

template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;

template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;

template <typename T>
using remove_reference_t = typename remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
...
...
@@ -221,5 +316,4 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
    return __builtin_bit_cast(Y, x);
}

} // namespace ck
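A few spot checks of the freestanding trait set introduced above; these mirror what <type_traits> would report and compile with or without the RTC path:

static_assert(ck::is_integral<unsigned long long>::value, "integral type");
static_assert(!ck::is_integral<float>::value, "not an integral type");
static_assert(ck::is_floating_point<long double>::value, "floating-point type");
static_assert(ck::is_same_v<ck::remove_cvref_t<const int&>, int>, "cv/ref stripped");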
include/ck/utility/type_convert.hpp
View file @ 2a30cfdd
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace ck {

-// Define the common macro for gfx94x models
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+// Define the common macro for MI300 models
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif

namespace {
namespace details {
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    half2_t vector_res;
    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;
    return vector_res;
}

[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    return amd_assembly_pk_add_f16(x, y);
}

} // namespace details
} // namespace
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    // Nan check
    if(x != x)
    {
        return uint16_t(0x7FC0);
    }
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};
    const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
    constexpr uint32_t rounding_bias      = uint32_t((1 << 15) - 1);
    return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
}
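The bias trick implements round-to-nearest-even on the 16 discarded mantissa bits: the addition carries into the kept upper half exactly when the tail exceeds half an ulp, or equals half while the lowest kept bit is odd. A worked check with illustrative bit patterns:

// 0x3F808000 is 1.00390625f: the tail is exactly 0x8000 (a tie), kept bit 0 -> no carry
static_assert(((0x3F808000u + ((0x3F808000u >> 16) & 1) + 0x7FFFu) >> 16) == 0x3F80u,
              "tie with even kept bit rounds down");
// 0x3F818000: the tail is again a tie but the kept bit is 1 -> carry, round up to even
static_assert(((0x3F818000u + ((0x3F818000u >> 16) & 1) + 0x7FFFu) >> 16) == 0x3F82u,
              "tie with odd kept bit rounds up");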
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    float x_fp32 = static_cast<float>(x);
    return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
          typename X,
...
...
@@ -51,17 +110,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
    return u.fp32;
}

-// convert fp32 to bfp16
+// convert fp32 to bfp16, round to nearest even
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};
#if CK_USE_RNE_BF16_CONVERSION
    return bf16_convert_rtn<bhalf_t>(x);
#else
    return uint16_t(u.int32 >> 16);
#endif
}
// convert bfp16 to fp16 via fp32
...
...
@@ -100,6 +157,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
    return type_convert<bhalf_t>(x_fp32);
}

template <>
inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
{
    return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
}

template <>
inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
{
    return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}

// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
...
@@ -163,10 +232,14 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_
fnuz_
t
f8_convert_sr
<
f8_
fnuz_
t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
long_index_t
>
(
&
x
),
x
);
#ifndef CK_CODE_GEN_RTC
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#else
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size_t
>
(
&
x
),
x
);
#endif
#if defined(__gfx94__)
union
{
...
...
@@ -189,36 +262,46 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
    constexpr bool clip           = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    return utils::
-        cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+        cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
            x, rng);
#endif
}
// convert fp16 to fp8 with stochastic rounding
template <>
-inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
-    return f8_convert_sr<f8_t>(type_convert<float>(x));
+    return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
    constexpr int seed               = 1254739;
-    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
-    return utils::
-        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
+    return utils::cast_to_f8<half_t,
+                             f8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
    constexpr int seed = 1254739;
-    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
#if defined(__gfx94__)
    union
    {
...
...
@@ -240,28 +323,36 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
-    return utils::
-        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<float,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
-    return f8_convert_sr<bf8_t>(type_convert<float>(x));
+    return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
    constexpr int seed               = 1254739;
-    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
-    return utils::
-        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+#ifndef CK_CODE_GEN_RTC
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
+#else
+    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
+#endif
+    return utils::cast_to_f8<half_t,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
...
...
@@ -271,7 +362,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
{
#if defined(__gfx94__)
    union
...
@@ -296,32 +387,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
utils
::
cast_to_f8
<
float
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
cast_to_f8
<
float
,
f8_
fnuz_
t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
#endif
}
// convert fp16 to fp8 with rounding to nearest even
template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
-    return f8_convert_rne<f8_t>(type_convert<float>(x));
+    return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
-    return utils::
-        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             f8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
template
<
>
inline
__host__
__device__
bf8_t
f8_convert_rne
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_
fnuz_
t
f8_convert_rne
<
bf8_
fnuz_
t
,
float
>
(
float
x
)
{
#if defined(__gfx94__)
union
...
...
@@ -345,44 +438,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
    constexpr bool clip           = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
    constexpr uint32_t rng        = 0;
-    return utils::
-        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<float,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp16 to bf8 with rounding to nearest even
template <>
-inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
{
#if defined(__gfx94__)
    // convert to float and use native conversion
-    return f8_convert_rne<bf8_t>(type_convert<float>(x));
+    return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
    constexpr bool clip              = true;
    constexpr f8_rounding_mode rm    = f8_rounding_mode::standard;
    constexpr uint32_t rng           = 0;
-    return utils::
-        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
#endif
}
// convert fp32 to fp8
template <>
-inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_fnuz_t>(x);
#else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_fnuz_t>(x);
#endif
}

// convert fp32 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<f8_ocp_t>(x);
#else
    return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp32
template <>
-inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
+inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
    float fval;
...
@@ -392,30 +500,95 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
return
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
x
);
#endif
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_t
>
(
f8x2_t
x
)
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_
fnuz_
t
>
(
f8x2_
fnuz_
t
x
)
{
#if defined(__gfx94__)
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
x
);
return
__builtin_amdgcn_cvt_pk_f32_fp8
(
i16val
,
0
);
#else
constexpr
bool
negative_zero_nan
=
true
;
const
auto
f8x2_v
=
vector_type
<
f8_t
,
2
>
(
x
);
const
auto
f8x2_v
=
vector_type
<
f8_
fnuz_
t
,
2
>
(
x
);
vector_type
<
float
,
2
>
f32x2_v
;
f32x2_v
.
template
AsType
<
float
>()(
Number
<
0
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
0
>
{}]);
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
0
>
{}]);
f32x2_v
.
template
AsType
<
float
>()(
Number
<
1
>
{})
=
utils
::
cast_from_f8
<
f8_t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_t
>()[
Number
<
1
>
{}]);
utils
::
cast_from_f8
<
f8_
fnuz_
t
,
float
,
negative_zero_nan
>
(
f8x2_v
.
template
AsType
<
f8_
fnuz_
t
>()[
Number
<
1
>
{}]);
return
f32x2_v
.
template
AsType
<
float2_t
>()[
Number
<
0
>
{}];
#endif
}
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f8x2_ocp_t
>
(
f8x2_ocp_t
x
)
{
#if CK_OCP_FP8_CVT_FAST_PATH
return
fp8_impl
::
cast_to_f32x2_from_f8x2
<
f8_ocp_t
::
default_interpret
>
(
x
.
AsType
<
fp8_impl
::
fp8x2_storage_t
>
()[
Number
<
0
>
{}]);
#else
return
float2_t
{
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
0
>
{}]),
fp8_impl
::
cast_from_f8
<
float
,
f8_ocp_t
::
wm
,
f8_ocp_t
::
we
,
false
>
(
x
.
AsType
<
fp8_storage_t
>
()[
Number
<
1
>
{}])};
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    float2_t res = {x_l, x_h};
#endif
    return res;
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; //-8

    int lo = i4s | EX;

    return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
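The magic constants lean on the fp16 bit layout (a short explanation, not present in the source): 0x6400 is fp16 1024.0, at whose exponent one mantissa ulp equals exactly 1.0, so OR-ing a 4-bit nibble n into the low mantissa bits yields the fp16 value 1024 + n; 0xE408 is fp16 -1032.0, so the packed add produces (1024 + n) - 1032 = n - 8, exactly the signed value a packed int4 lane encodes.

// Worked example (illustrative): nibble n = 0b0011 (3)
//   0x6400 | 0x0003 = 0x6403 -> fp16 1027.0   (1024 + 3)
//   1027.0 + (-1032.0) = -5.0                 (3 - 8)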
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
    bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
    return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
...
...
@@ -428,42 +601,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
// convert fp16 to fp8
template <>
-inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<f8_fnuz_t>(x);
#else
    return f8_convert_rne<f8_fnuz_t>(x);
#endif
}

// convert fp16 to fp8
template <>
inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_ocp_t>(x);
#else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_ocp_t>(x);
#endif
}
// convert fp8 to fp16
template <>
-inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
+inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
{
#if defined(__gfx94__)
    // use native conversion to float and convert to fp16
    return type_convert<half_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}

// convert fp32 to bf8
template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_fnuz_t>(x);
#else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<bf8_ocp_t>(x);
#else
    return f8_convert_rne<bf8_ocp_t>(x);
#endif
}
// convert bf8 to fp32
template <>
-inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
+inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
    float fval;
...
...
@@ -473,31 +668,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
    return fval;
#else
    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
#endif
}

// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
    return f8_convert_sr<bf8_fnuz_t>(x);
#else
    return f8_convert_rne<bf8_fnuz_t>(x);
#endif
}
// convert fp16 to bf8
template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
{
#if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_ocp_t>(x);
#else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_ocp_t>(x);
#endif
}

// convert bf8 to fp16
template <>
-inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
+inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
{
#if defined(__gfx94__)
    // use native conversion to float and convert to fp16
    return type_convert<half_t>(type_convert<float>(x));
#else
    constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
#endif
}
...
...
@@ -512,70 +718,1297 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
    }
}
#endif

-template <typename Y, typename X, index_t NumElems>
-inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
-{
-    for(std::size_t i = 0; i < NumElems; i++)
-    {
-        y[i] = type_convert<Y>(x[i]);
-    }
-}

// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}

// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l     = utils::sat_convert_to_type<f4_t>(x[1] / scale);
    uint8_t h     = utils::sat_convert_to_type<f4_t>(x[0] / scale);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
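// A minimal sketch of driving the packed rne path directly (the function name and
// the scale of 1.0f are illustrative assumptions):
//
// __device__ ck::f4x2_t example_pack_two_fp4(float a, float b)
// {
//     ck::float2_t v{a, b};
//     // high nibble <- v[0], low nibble <- v[1], per the packing above
//     return ck::f4_convert_rne(v, 1.0f);
// }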
-// Declare a template function for bf16 conversion using RTN
-template <typename Y, typename X>
-__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{}, tmp_values{};
    // pack two fp32 values per builtin call
    ck::static_for<0, 16, 1>{}([&](auto i) {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
            tmp_values.bitwise, x[2 * i], x[2 * i + 1], scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    });
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{};
    // pack one 4-bit value per iteration; x[0] ends up in the most significant nibble
    ck::static_for<0, 32, 1>{}([&](auto i) {
        auto tmp = utils::sat_convert_to_type<f4_t>(x[i] / scale);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    });
    return f4_values.f4x32_array;
#endif
}
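// One packing subtlety in the emulation path above: the accumulator is shifted
// left before each OR, so the first value packed lands in the most significant
// position. Reduced to two nibbles in one byte (self-checking illustration):
constexpr uint8_t nibble_pack_order_example()
{
    uint8_t packed = 0;
    packed <<= 4;
    packed |= 0xA; // first value packed -> high nibble
    packed <<= 4;
    packed |= 0x5; // second value packed -> low nibble
    return packed;
}
static_assert(nibble_pack_order_example() == 0xA5, "first packed value lands in the high nibble");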
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{{x}};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
        value.bitwise, float_values.float2_array, rng, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
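// For intuition (illustrative, host-only; not the library's RNG): stochastic
// rounding picks the upper grid point with probability equal to the fractional
// position between grid points, so the expected decoded value equals the
// input. sat_convert_to_type_sr applies the same idea to the fp4 grid, driven
// by the `rng` word computed above.
//
//     #include <cmath>
//     #include <random>
//
//     inline float stochastic_round_to_grid(float x, float step, std::mt19937& gen)
//     {
//         float lo   = step * std::floor(x / step);
//         float frac = (x - lo) / step; // in [0, 1)
//         return std::bernoulli_distribution(frac)(gen) ? lo + step : lo;
//     }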
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
    uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0}, tmp_values{0};
    union
    {
        float2_t floatx2_array[16];
        float32_t floatx32_array;
    } float_values{};
    float_values.floatx32_array = x;
    // pack two fp32 values per builtin call
    ck::static_for<0, 16, 1>{}([&](auto i) {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
            tmp_values.bitwise, float_values.floatx2_array[i], rng, scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    });
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0};
    // pack one 4-bit value per iteration; x[0] ends up in the most significant nibble
    ck::static_for<0, 32, 1>{}([&](auto i) {
        auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[i] / scale, rng);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    });
    return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float scale = 1.0f;
    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}

// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};
    value.f4x2_array[0] = x;
    float scale = 1.0f;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
    float2_t ret{
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
        utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                              x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
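// Illustrative sketch (hypothetical helper, not part of this header): the fp4
// value stored is x / scale, so a reader multiplies the decoded value by the
// same scale; and because the rounding mode lives inside the specializations
// above, the call site is identical whether CK_USE_SR_F4_CONVERSION is 0 or 1.
__host__ __device__ inline float f4_round_trip_example(float x, float scale)
{
    f4_t q = f4_convert_rne(x, scale);     // quantize: stores x / scale
    return type_convert<float>(q) * scale; // dequantize: decode, then undo the scale
}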
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    float scale = 1.0f;
    // unpack two fp32 values per builtin call
    ck::static_for<0, 16, 1>{}([&](auto i) {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        op             = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret[2 * i]     = op[0];
        ret[2 * i + 1] = op[1];
    });
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // unpack one fp4 value per iteration; low nibble first within each byte
    ck::static_for<0, 16, 1>{}([&](auto i) {
        float_values.float_array[2 * i] =
            utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                  f4_values.f4x2_array[i]
                                      .template AsType<f4x2_pk_t>()[Number<0>{}]
                                      .unpack<>(Number<0>{}));
        float_values.float_array[2 * i + 1] =
            utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                  f4_values.f4x2_array[i]
                                      .template AsType<f4x2_pk_t>()[Number<0>{}]
                                      .unpack<>(Number<1>{}));
    });
    return float_values.float32_array;
#endif
}
/**
 * @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts it
 * to the 6-bit floating-point format (f6_t).
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
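// Illustrative sketch (hypothetical helper, not part of this header): the
// scale parameter is what makes f6_convert_rne usable for block-scaled
// (MX-style) quantization. One common policy maps the block's max magnitude
// onto the largest e2m3 magnitude (7.5), assuming the saturating emulation
// path:
inline f6_t quantize_with_block_scale_example(const float* block, int n, int i)
{
    float amax = 0.0f;
    for(int j = 0; j < n; ++j)
    {
        float mag = block[j] < 0.0f ? -block[j] : block[j];
        amax      = mag > amax ? mag : amax;
    }
    float scale = amax > 0.0f ? amax / 7.5f : 1.0f; // 7.5 == max normal |e2m3|
    return f6_convert_rne(block[i], scale);
}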
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion
 * with rounding to nearest even to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1; // upper 16 floats of x
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
 *
 * Divides the input by the specified scale, then performs saturation and conversion
 * to f6_t based on a pseudo-randomly generated seed.
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion
 * with stochastic rounding to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Specializes the type conversion template for converting a float into the 6-bit float
 * type (f6_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float value to be converted.
 * @return The converted f6_t value.
 */
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}

/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into the
 * vector of 32 6-bit float types (f6x32_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float vector to be converted.
 * @return The converted f6x32_t vector.
 */
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
 * float.
 *
 * Interprets an f6_t value as a float using the default scale factor of 1.
 *
 * @param x The 6-bit float (f6_t) value to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the vector of 32 6-bit float
 * types (f6x32_t) to a vector of 32 floats.
 *
 * Interprets the f6_t values as floats using the default scale factor of 1.
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts
 * it to a 6-bit BF6 floating-point format.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
 * round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts
 * it to a 6-bit BF6 floating-point format.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = in1 + 1; // upper 16 floats of x
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
 *
 * Divides the input by the specified scale, and converts the result to a
 * 6-bit BF6 floating-point format with stochastic rounding.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
 * rounding.
 *
 * Divides the input by the specified scale, and converts the result to a
 * 6-bit BF6 floating-point format with stochastic rounding.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Specializes float-to-bf6_t conversion.
 *
 * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
 * otherwise uses round-to-nearest-even.
 *
 * @param x Input float value to convert.
 * @return Converted bf6_t value.
 */
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the conversion of a vector of 32 floats to bf6x32_t.
 *
 * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
 * otherwise uses round-to-nearest-even.
 *
 * @param x Input float vector to convert.
 * @return Converted bf6x32_t vector.
 */
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a bf6_t value to float.
 *
 * Interprets the bf6_t value using the default scale factor of 1 and returns
 * its floating-point representation.
 *
 * @param x The bf6_t value to convert.
 * @return The float representation of the given bf6_t value.
 */
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to a
 * vector of 32 floats.
 *
 * Interprets the bf6x32_t value using the default scale factor of 1 and returns
 * its floating-point representation.
 *
 * @param x The bf6x32_t value to convert.
 * @return The float representation of the given vector.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
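// Range comparison, for orientation (illustrative; assumes the saturating
// emulation path): bf6 (e3m2) trades one mantissa bit of f6 (e2m3) for an
// extra exponent bit, so with unit scale 16.0f round-trips through bf6 but
// clamps to 7.5f through f6.
inline void fp6_range_example()
{
    float as_bf6 = type_convert<float>(bf6_convert_rne(16.0f)); // -> 16.0f
    float as_f6  = type_convert<float>(f6_convert_rne(16.0f));  // -> 7.5f (saturated)
    (void)as_bf6;
    (void)as_f6;
}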
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                              const std::array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}
#endif
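#ifndef CK_CODE_GEN_RTC
// Host-side usage sketch of the std::array overload above (values illustrative):
inline void array_convert_example()
{
    std::array<float, 4> src{0.5f, 1.0f, -2.0f, 3.0f};
    std::array<half_t, 4> dst{};
    array_convert(dst, src); // applies type_convert<half_t>(float) element-wise
}
#endif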
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}

} // namespace ck