gaoqiong / composable_kernel_ROCM

Commit 7572a691 — authored Feb 15, 2025 by coderfeli

    merge develop

Parents: 7796fc73, 6b6fcd37
Changes: 465 files
Showing 20 changed files with 3999 additions and 539 deletions (+3999 −539)
include/ck/utility/random_gen.hpp                              +15    −11
include/ck/utility/scaled_type_convert.hpp                     +877   −0
include/ck/utility/sequence.hpp                                +5     −1
include/ck/utility/static_buffer.hpp                           +4     −2
include/ck/utility/statically_indexed_array_multi_index.hpp    +18    −23
include/ck/utility/tuple.hpp                                   +8     −8
include/ck/utility/tuple_helper.hpp                            +10    −4
include/ck/utility/type.hpp                                    +316   −49
include/ck/utility/type_convert.hpp                            +1416  −72
include/ck_tile/core.hpp                                       +4     −2
include/ck_tile/core/algorithm/static_encoding_pattern.hpp     +210   −0
include/ck_tile/core/arch/arch.hpp                             +51    −6
include/ck_tile/core/arch/generic_memory_space_atomic.hpp      +293   −10
include/ck_tile/core/config.hpp                                +19    −3
include/ck_tile/core/container/tuple.hpp                       +1     −1
include/ck_tile/core/numeric/bfloat16.hpp                      +11    −1
include/ck_tile/core/numeric/float8.hpp                        +593   −340
include/ck_tile/core/numeric/half.hpp                          +6     −5
include/ck_tile/core/numeric/numeric.hpp                       +2     −1
include/ck_tile/core/numeric/pk_int4.hpp                       +140   −0

(Too many changes to show: to preserve performance only 465 of 465+ files are displayed; the diffs below cover the first 20 files.)
include/ck/utility/random_gen.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#include <ck/utility/ignore.hpp>
 #include "ck/ck.hpp"

+#ifdef CK_CODE_GEN_RTC
+using uint8_t  = unsigned char;
+using uint16_t = unsigned short;
+using uint32_t = unsigned int;
+#endif

 namespace ck {

 // Pseudo random number generator
 // version for fp32
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
     ...
@@ -25,7 +30,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // version for fp16
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
     ...
@@ -40,15 +45,14 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }

 // return 0 if data is not fp16 or fp32
 template <typename T,
-          uint32_t seed_t,
-          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
+          uint32_t seed_t,
+          ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
 __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
 {
-    std::ignore = id;
-    std::ignore = val;
-    std::ignore = seed;
+    ck::ignore = id;
+    ck::ignore = val;
+    ck::ignore = seed;
     return 0;
 }
 ...
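The generator above is what the stochastic-rounding converters later in this commit feed with the value's own address and bit pattern. A minimal host-side sketch of that pattern (variable names are illustrative; the seed constant 1254739 is the one the f8_convert_sr specializations in type_convert.hpp actually use):

    // Sketch only: derive per-element rounding noise from the value itself,
    // the same way the f8_convert_sr<> specializations below do.
    float v = 0.3f;
    constexpr uint32_t seed = 1254739;
    uint32_t rng = ck::prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&v), v);
    // 'rng' then serves as the stochastic-rounding noise for the fp8 cast.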
include/ck/utility/scaled_type_convert.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"

#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif

namespace ck {

// Declare a template function for scaled conversion
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif

// convert f8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.data);
#else
    return type_convert<float>(scale) * type_convert<float>(x);
#endif
}

// convert bf8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.data);
#else
    return type_convert<float>(scale) * type_convert<float>(x);
#endif
}

// convert 2 x f8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t
scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
                    scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<1>{}])};
#endif
}

// convert 2 x bf8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t
scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale, bf8x2_ocp_t x)
#else
inline __host__ float2_t
scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale, bf8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),
                    scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<1>{}])};
#endif
}

// convert 16 x f8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t
scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale, f8x16_ocp_t x)
#else
inline __host__ float16_t
scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale, f8x16_ocp_t x)
#endif
{
    union
    {
        f8x16_ocp_t f8_1x16;
        f8x2_ocp_t f8_2x8[8];
    } in{x};
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}([&](auto i) {
        out.float_2x8[i] = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, in.f8_2x8[i]);
    });

    return out.float_1x16;
}

// convert 16 x bf8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t
scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale, bf8x16_ocp_t x)
#else
inline __host__ float16_t
scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale, bf8x16_ocp_t x)
#endif
{
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } in{x};
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}([&](auto i) {
        out.float_2x8[i] = scaled_type_convert<float2_t, bf8x2_ocp_t>(scale, in.bf8_2x8[i]);
    });

    return out.float_1x16;
}

// convert 32 x f8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale, f8x32_ocp_t x)
#else
inline __host__ float32_t
scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale, f8x32_ocp_t x)
#endif
{
    union
    {
        f8x32_ocp_t f8_1x32;
        f8x16_ocp_t f8_16x2[2];
    } in{x};
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}([&](auto i) {
        out.float_16x2[i] = scaled_type_convert<float16_t, f8x16_ocp_t>(scale, in.f8_16x2[i]);
    });

    return out.float_1x32;
}

// convert 32 x bf8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale, bf8x32_ocp_t x)
#else
inline __host__ float32_t
scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale, bf8x32_ocp_t x)
#endif
{
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } in{x};
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}([&](auto i) {
        out.float_16x2[i] = scaled_type_convert<float16_t, bf8x16_ocp_t>(scale, in.bf8_16x2[i]);
    });

    return out.float_1x32;
}

// convert fp32 to fp8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t
scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t
scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#else
inline __host__ bf8x2_ocp_t
scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
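In every branch above the e8m0_bexp_t scale enters the computation only through type_convert<float>(scale): an E8M0 value is a bare biased exponent, so the decoded scale is a pure power of two. A standalone sketch of that decode, assuming the usual OCP MX bias of 127 (the bias and helper name are assumptions, not taken from this diff):

    #include <cmath>
    #include <cstdint>

    // Hypothetical helper, for illustration only: decode an E8M0 biased exponent.
    // value = 2^(e - 127); e == 0xFF is reserved for NaN in the OCP MX spec.
    inline float e8m0_decode(std::uint8_t e) { return std::ldexp(1.0f, static_cast<int>(e) - 127); }

    // A scaled conversion is then just decoded_scale * element value, e.g.
    // e8m0_decode(128) == 2.0f, so a stored fp8 value of 1.5 decodes to 3.0.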
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT

// convert fp4 to fp32
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};

    float_values.float2_array =
        __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);

    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(scale, x);
#endif
}

// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_bexp_t scale, f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};

    value.f4x2_array[0] = x;

    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
#else
    float2_t ret{
        utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
        utils::to_float<f4_t>(scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}

// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale, f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};

    float2_t op;
    float32_t ret;

    // TODO: pack in a loop
    bitwise_value.f4x2_array[0] = value.fp4x2[0];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[0] = op[0];
    ret[1] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[1];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[2] = op[0];
    ret[3] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[2];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[4] = op[0];
    ret[5] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[3];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[6] = op[0];
    ret[7] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[4];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[8] = op[0];
    ret[9] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[5];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[10] = op[0];
    ret[11] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[6];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[12] = op[0];
    ret[13] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[7];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[14] = op[0];
    ret[15] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[8];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[16] = op[0];
    ret[17] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[9];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[18] = op[0];
    ret[19] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[10];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[20] = op[0];
    ret[21] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[11];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[22] = op[0];
    ret[23] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[12];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[24] = op[0];
    ret[25] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[13];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[26] = op[0];
    ret[27] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[14];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[28] = op[0];
    ret[29] = op[1];
    bitwise_value.f4x2_array[0] = value.fp4x2[15];
    op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
    ret[30] = op[0];
    ret[31] = op[1];
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};

    // TODO: pack in a loop
    float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[0] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[1] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[2] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[3] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[4] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[5] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    float_values.float_array[6] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
    float_values.float_array[7] = utils::to_float<f4_t>(scale, f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));

    return float_values.float32_array;
#endif
}

// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}

// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale, float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}

// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t
scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}

/**
 * @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x The f6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};

    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));

    return out.float_array[0];
#else
    return utils::to_float<f6_t>(scale, x);
#endif
}

/**
 * @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The f6x32_t vector to be converted.
 * @return The converted float vector representation of the input.
 */
template <>
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale, f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};

    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });

    return out.float_vector;
#endif
}

/**
 * @brief Converts a 6-bit floating-point value (bf6_t) to a 32-bit float,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x The bf6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};

    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));

    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(scale, x);
#endif
}

/**
 * @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The bf6x32_t vector to be converted.
 * @return The converted vector of 32 float representation of the input.
 */
template <>
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale, bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};

    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });

    return out.float_vector;
#endif
}

/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the specified
 * scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x The float value to convert.
 * @return The converted 6-bit floating-point value (f6_t).
 */
template <>
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}

/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
 * applying the specified scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The float vector to convert.
 * @return The converted vector of 6-bit floating-point values (f6x32_t).
 */
template <>
inline __host__ __device__ f6x32_t
scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}

/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
 * scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x The float value to convert.
 * @return The converted 6-bit floating-point value (bf6_t).
 */
template <>
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}

/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
 * applying the specified scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The float vector to convert.
 * @return The converted 6-bit floating-point vector (bf6x32_t).
 */
template <>
inline __host__ __device__ bf6x32_t
scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale, float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}

#endif // #if CK_USE_NATIVE_MX_SUPPORT

} // namespace ck
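Taken together, the new header gives one entry point for MX-style block decode and encode. A hedged usage sketch in device code (kernel fragment only; the packed block and scale are placeholders, and CK_USE_OCP_FP8 plus a device build are assumed so the __device__ overloads exist):

    // Sketch, not from the diff: decode one 32-element MX block of OCP fp8 values.
    __device__ void decode_block(const ck::f8x32_ocp_t& packed, ck::e8m0_bexp_t scale, float* out)
    {
        ck::float32_t v = ck::scaled_type_convert<ck::float32_t>(scale, packed); // 32 scaled floats
        for(int i = 0; i < 32; ++i)
            out[i] = v[i];
    }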
include/ck/utility/sequence.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

+#ifndef CK_CODE_GEN_RTC
 #include <ostream>
+#endif

 #include "ck/utility/integral_constant.hpp"
 #include "ck/utility/type.hpp"
 ...
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
 } // namespace ck

+#ifndef CK_CODE_GEN_RTC
 template <ck::index_t... Is>
 std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
 {
 ...
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
     os << S::At(S::Size() - ck::Number<1>{}).value << "}";
     return os;
 }
+#endif
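The new guard exists because hipRTC-style builds (CK_CODE_GEN_RTC) have no <ostream>; a normal host build keeps the stream operator unchanged. A small host-only illustration (assumes the CK headers are on the include path):

    #include <iostream>
    #include "ck/utility/sequence.hpp"

    int main()
    {
        // Streams the elements of a compile-time sequence wrapped in braces
        // via the operator<< guarded above (host-only path).
        std::cout << ck::Sequence<0, 1, 2>{} << '\n';
    }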
include/ck/utility/static_buffer.hpp
 ...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr auto GetAsType(Number<I> i) const
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
 ...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
 ...
include/ck/utility/statically_indexed_array_multi_index.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
 #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
 ...
@@ -35,10 +35,9 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
 // is the alias of the latter. This is because compiler cannot infer the NSize if
 // using MultiIndex<NSize>
 // TODO: how to fix this?
-template <typename... Ys,
-          typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
 ...
@@ -47,10 +46,9 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
     return y;
 }

-template <typename... Ys,
-          typename X,
-          enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
+template <typename... Ys,
+          typename X,
+          enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
 __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
 {
     static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
 ...
@@ -59,10 +57,9 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
     return y;
 }

-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
 ...
@@ -73,10 +70,9 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
     return r;
 }

-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
 ...
@@ -87,10 +83,9 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
     return r;
 }

-template <typename... Xs,
-          typename Y,
-          enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
+template <typename... Xs,
+          typename Y,
+          enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 {
     static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
 ...
@@ -104,7 +99,7 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
 // MultiIndex = scalar * MultiIndex
 template <typename... Xs,
           typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 {
     constexpr index_t NSize = sizeof...(Xs);
 ...
@@ -117,7 +112,7 @@ __host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
 // MultiIndex = MultiIndex * scalar
 template <typename... Xs,
           typename Y,
-          enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
+          enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
 __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
 {
     return a * x;
 ...
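These operators give MultiIndex (a Tuple of index_t) its element-wise arithmetic; the change only swaps the std:: traits for the ck:: replacements so the overloads also resolve under RTC builds. A small sketch of what they enable (ck::make_multi_index is assumed from the same header family and is not part of this diff):

    // Sketch, assuming the CK headers: element-wise MultiIndex arithmetic.
    auto a = ck::make_multi_index(1, 2, 3);
    auto b = ck::make_multi_index(4, 5, 6);
    auto c = a + b; // element-wise sum: {5, 7, 9}
    auto d = 2 * c; // scalar * MultiIndex: {10, 14, 18}, now gated on ck::is_integral<int>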
include/ck/utility/tuple.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 ...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
     template <typename T,
               typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, bool>::type = false>
-    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
+    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
     {
     }
 ...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
 template <typename Key, typename Data>
 __host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
 {
-    return std::forward(x.mData);
+    return ck::forward(x.mData);
 }

 template <typename Indices, typename... Xs>
 ...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
                   !is_same<remove_cvref_t<Y>, TupleImpl>::value, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
     {
     }

     template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Ys&&... ys)
-        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
     {
         static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                       "wrong! inconsistent size");
 ...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     template <typename Y,
               typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                  bool>::type = false>
-    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
+    __host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
     {
     }

     template <typename... Ys,
               typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type = false>
-    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
+    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
     {
     }
 ...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
 template <typename... Xs>
 __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
 {
-    return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
+    return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
 }

 // https://en.cppreference.com/w/cpp/utility/tuple/tie
 ...
include/ck/utility/tuple_helper.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "functional4.hpp"
 #include "tuple.hpp"
+#ifndef CK_CODE_GEN_RTC
 #include "is_detected.hpp"
+#endif

 namespace ck {
 ...
@@ -29,7 +31,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
                                                              const Tuple<Y&...>& ty)
 {
     return unpack2(
-        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
 ...
@@ -38,7 +40,7 @@ template <typename... X, typename... Y>
 __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
 {
     return unpack2(
-        [&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
         tx,
         ty);
 }
 ...
@@ -157,13 +159,17 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
     }
 }

+#ifndef CK_CODE_GEN_RTC
 template <typename T>
-using is_tuple = decltype(std::declval<T&>().IsTuple());
+using is_tuple = decltype(ck::declval<T&>().IsTuple());
+#endif

 template <typename... Ts>
 __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
 {
+#ifndef CK_CODE_GEN_RTC
     return (is_detected<is_tuple, Ts>::value || ...);
+#endif
 }

 template <index_t depth = 0, typename T>
 ...
include/ck/utility/type.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/ck.hpp"
-#include "ck/utility/integral_constant.hpp"
 #include "ck/utility/enable_if.hpp"
+#include "ck/utility/integral_constant.hpp"

 namespace ck {

+#ifdef CK_CODE_GEN_RTC
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT1(name)         \
+    template <class T>                       \
+    struct name : bool_constant<__##name(T)> \
+    {                                        \
+    }
+
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAIT2(name)            \
+    template <class T, class U>                 \
+    struct name : bool_constant<__##name(T, U)> \
+    {                                           \
+    }
+
+// NOLINTNEXTLINE
+#define CK_BUILTIN_TYPE_TRAITN(name)             \
+    template <class... Ts>                       \
+    struct name : bool_constant<__##name(Ts...)> \
+    {                                            \
+    }
+
+CK_BUILTIN_TYPE_TRAIT1(is_class);
+CK_BUILTIN_TYPE_TRAIT1(is_pointer);
+CK_BUILTIN_TYPE_TRAIT1(is_reference);
+CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
+CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
+CK_BUILTIN_TYPE_TRAIT2(is_base_of);
+
+template <class T>
+struct remove_cv
+{
+    using type = T;
+};
+template <class T>
+struct remove_cv<const T> : remove_cv<T>
+{
+};
+template <class T>
+struct remove_cv<volatile T> : remove_cv<T>
+{
+};
+
+template <class T>
+struct remove_reference
+{
+    typedef T type;
+};
+template <class T>
+struct remove_reference<T&>
+{
+    typedef T type;
+};
+template <class T>
+struct remove_reference<T&&>
+{
+    typedef T type;
+};
+
+template <class T>
+struct remove_pointer
+{
+    typedef T type;
+};
+template <class T>
+struct remove_pointer<T*>
+{
+    typedef T type;
+};
+template <class T>
+struct remove_pointer<T* const>
+{
+    typedef T type;
+};
+template <class T>
+struct remove_pointer<T* volatile>
+{
+    typedef T type;
+};
+template <class T>
+struct remove_pointer<T* const volatile>
+{
+    typedef T type;
+};
+
+template <typename T>
+constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
+{
+    return static_cast<T&&>(t_);
+}
+
+template <typename T>
+constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
+{
+    return static_cast<T&&>(t_);
+}
+
+template <class T>
+struct is_const : public integral_constant<bool, false>
+{
+};
+template <class T>
+struct is_const<const T> : public integral_constant<bool, true>
+{
+};
+
+template <class T>
+inline constexpr bool is_const_v = is_const<T>::value;
+
+template <typename T>
+inline constexpr bool is_reference_v = is_reference<T>::value;
+
+template <class T>
+struct remove_const
+{
+    typedef T type;
+};
+template <class T>
+struct remove_const<const T>
+{
+    typedef T type;
+};
+
+template <class T>
+using remove_const_t = typename remove_const<T>::type;
+
+template <class T>
+inline constexpr bool is_class_v = is_class<T>::value;
+
+template <class T>
+inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
+
+// template <typename T>
+// T&& declval() noexcept;
+template <class T, class U = T&&>
+U private_declval(int);
+template <class T>
+T private_declval(long);
+template <class T>
+auto declval() noexcept -> decltype(private_declval<T>(0));
+
+template <class...>
+using void_t = void;
+
+#else
+#include <utility>
+#include <type_traits>
+
+using std::declval;
+using std::forward;
+using std::is_base_of;
+using std::is_class;
+using std::is_class_v;
+using std::is_const_v;
+using std::is_pointer;
+using std::is_reference;
+using std::is_reference_v;
+using std::is_trivially_copyable;
+using std::is_trivially_copyable_v;
+using std::is_unsigned;
+using std::remove_const_t;
+using std::remove_cv;
+using std::remove_pointer;
+using std::remove_reference;
+using std::void_t;
+#endif
+
 template <typename X, typename Y>
 struct is_same : public integral_constant<bool, false>
 ...
@@ -19,23 +182,128 @@ struct is_same<X, X> : public integral_constant<bool, true>
 {
 };

+template <typename X>
+struct is_floating_point : public integral_constant<bool, false>
+{
+};
+template <>
+struct is_floating_point<float> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_floating_point<double> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_floating_point<long double> : public integral_constant<bool, true>
+{
+};
+
+template <typename X>
+struct is_integral : public integral_constant<bool, false>
+{
+};
+template <>
+struct is_integral<int> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<unsigned int> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<long> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<unsigned long> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<short> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<unsigned short> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<long long> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<unsigned long long> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<char> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<signed char> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<unsigned char> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<wchar_t> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<char16_t> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<char32_t> : public integral_constant<bool, true>
+{
+};
+template <>
+struct is_integral<bool> : public integral_constant<bool, true>
+{
+};
+
 template <typename X, typename Y>
 inline constexpr bool is_same_v = is_same<X, Y>::value;

 template <typename X, typename Y>
 inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;

+template <typename T>
+inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
+
 template <typename T>
-using remove_reference_t = typename std::remove_reference<T>::type;
+using remove_reference_t = typename remove_reference<T>::type;

 template <typename T>
-using remove_cv_t = typename std::remove_cv<T>::type;
+using remove_cv_t = typename remove_cv<T>::type;

 template <typename T>
-using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
+using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;

 template <typename T>
-using remove_pointer_t = typename std::remove_pointer<T>::type;
+using remove_pointer_t = typename remove_pointer<T>::type;

 template <typename T>
-inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
+inline constexpr bool is_pointer_v = is_pointer<T>::value;

 template <typename Y,
           typename X,
           typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
 __host__ __device__ constexpr Y bit_cast(const X& x)
 ...
@@ -45,5 +313,4 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
     return __builtin_bit_cast(Y, x);
 }

 } // namespace ck
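The CK_BUILTIN_TYPE_TRAIT* macros work because clang (and gcc) expose __is_class, __is_base_of, and similar checks as compiler builtins, so a freestanding RTC compile can define the traits without <type_traits>. A standalone sketch of the same pattern, with illustrative names that are not CK's:

    // Minimal illustration of building a type trait from a compiler builtin.
    template <bool B>
    struct demo_bool_constant
    {
        static constexpr bool value = B;
    };

    template <class T>
    struct demo_is_class : demo_bool_constant<__is_class(T)> // clang/gcc builtin
    {
    };

    struct Foo
    {
    };
    static_assert(demo_is_class<Foo>::value, "Foo is a class");
    static_assert(!demo_is_class<int>::value, "int is not a class");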
include/ck/utility/type_convert.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#endif
namespace
{
namespace
details
{
[[
maybe_unused
]]
__host__
half2_t
pk_add_f16
(
const
half2_t
&
x
,
const
half2_t
&
y
)
{
half2_t
vector_res
;
vector_res
.
x
=
x
.
x
+
y
.
x
;
vector_res
.
y
=
x
.
y
+
y
.
y
;
return
vector_res
;
}
[[
maybe_unused
]]
__device__
half2_t
pk_add_f16
(
const
half2_t
&
x
,
const
half2_t
&
y
)
{
return
amd_assembly_pk_add_f16
(
x
,
y
);
}
}
// namespace details
}
// namespace
// Declare a template function for bf16 conversion using RTN
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
bf16_convert_rtn
(
X
x
);
// Convert fp32 to bf16 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
float
>
(
float
x
)
{
// Nan check
if
(
x
!=
x
)
{
return
uint16_t
(
0x7FC0
);
}
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
const
uint32_t
first_bf16_mantisa_bit
=
((
u
.
int32
>>
16
)
&
1
);
constexpr
uint32_t
rounding_bias
=
uint32_t
((
1
<<
15
)
-
1
);
return
uint16_t
((
u
.
int32
+
first_bf16_mantisa_bit
+
rounding_bias
)
>>
16
);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
bf16_convert_rtn
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
bf16_convert_rtn
<
bhalf_t
>
(
x_fp32
);
}
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
...
...
@@ -28,13 +87,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
ck
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
using
NonConstY
=
ck
::
remove_const_t
<
Y
>
;
using
NonConstX
=
ck
::
remove_const_t
<
X
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
}
...
...
@@ -51,17 +110,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return
u
.
fp32
;
}
// convert fp32 to bfp16
// convert fp32 to bfp16
, round to nearest even
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
#if CK_USE_RNE_BF16_CONVERSION
return
bf16_convert_rtn
<
bhalf_t
>
(
x
);
#else
return
uint16_t
(
u
.
int32
>>
16
);
#endif
}
// convert bfp16 to fp16 via fp32
...
...
@@ -116,7 +173,7 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
...
...
@@ -178,7 +235,11 @@ template <>
inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
{
    constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
    union
    {
...
...
@@ -218,7 +279,12 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
    constexpr bool clip           = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed            = 1254739;
#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
    return utils::cast_to_f8<half_t,
                             f8_fnuz_t,
                             negative_zero_nan,
...
...
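// Illustrative, standalone sketch of the stochastic-rounding idea (the f8 bit
// packing itself is done by utils::cast_to_f8 above): rounding a value up with
// probability equal to its fractional distance keeps the expected value unbiased.
// The LCG below is only a stand-in; any uniform generator works.
#include <cstdint>
#include <cstdio>

int main()
{
    const float x  = 0.3f; // value between two representable neighbours 0 and 1
    uint32_t state = 1254739u;
    double sum     = 0.0;
    const int n    = 100000;
    for(int i = 0; i < n; ++i)
    {
        state         = state * 1664525u + 1013904223u;
        const float u = (state >> 8) * (1.0f / 16777216.0f); // uniform in [0, 1)
        sum += (u < x) ? 1.0 : 0.0;                          // round up with probability ~0.3
    }
    std::printf("mean rounded value: %f (expected ~0.3)\n", sum / n);
    return 0;
}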
@@ -232,7 +298,11 @@ template <>
inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
{
    constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
#endif
#if defined(__gfx94__)
    union
    {
...
...
@@ -274,7 +344,12 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x
    constexpr bool clip           = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed            = 1254739;
#ifndef CK_CODE_GEN_RTC
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
#else
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
#endif
    return utils::cast_to_f8<half_t,
                             bf8_fnuz_t,
                             negative_zero_nan,
...
...
@@ -465,6 +540,57 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    float2_t res = {x_l, x_h};
#endif
    return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif

    const int EX  = 0x64006400; // fp16 1024.0 in each half
    const int SUB = 0xE408E408; //-8

    int lo = i4s | EX;

    return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
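// Illustrative, standalone check of the trick above (not part of this header):
// 0x6400 is fp16 1024.0, and in the [1024, 2048) binade the fp16 mantissa step is
// exactly 1.0, so OR-ing a 4-bit code v into the low mantissa bits yields the fp16
// value 1024 + v. Adding fp16 -1032.0 (0xE408, packed twice in SUB) then leaves
// v - 8, i.e. the signed int4 value.
#include <cassert>

int main()
{
    for(int v = 0; v < 16; ++v)
    {
        const float decoded = static_cast<float>(1024 + v) - 1032.0f;
        assert(decoded == static_cast<float>(v - 8));
    }
    return 0;
}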
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

    float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
    bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
    return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
...
...
@@ -583,79 +709,1297 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
}
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                              const std::array<X, NumElems>& x)
{
    for(std::size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}

// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}

template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(std::size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}

// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l     = utils::sat_convert_to_type<f4_t>(x[1] / scale);
    uint8_t h     = utils::sat_convert_to_type<f4_t>(x[0] / scale);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
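// Illustrative, standalone check (not part of this header): in the fallback path
// above, x[0] ends up in the high nibble and x[1] in the low nibble of the packed byte.
#include <cassert>
#include <cstdint>

int main()
{
    const uint8_t h      = 0x3, l = 0x5; // stand-ins for the fp4 codes of x[0] and x[1]
    const uint8_t packed = static_cast<uint8_t>((h << 4) | l);
    assert(((packed >> 4) & 0xF) == h);
    assert((packed & 0xF) == l);
    return 0;
}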
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{}, tmp_values{};
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{x};
    // pack two fp32 values per iteration into one fp4x2 byte
    for(int i = 0; i < 16; ++i)
    {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise,
                                                                      float_values.float_array[2 * i],
                                                                      float_values.float_array[2 * i + 1],
                                                                      scale,
                                                                      0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    }
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{};
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{x};
    // pack element 0 into the most significant nibble down to element 31 in the least significant one
    for(int i = 0; i < 32; ++i)
    {
        auto tmp = utils::sat_convert_to_type<f4_t>(float_values.float_array[i] / scale);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    }
    return f4_values.f4x32_array;
#endif
}
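// Illustrative, standalone check of the fallback packing order above: pushing the
// codes for elements 0..31 through "shift left by 4, OR in the new code" leaves
// element 0 in the most significant nibble and element 31 in the least significant
// one. __uint128_t is a compiler extension (Clang/GCC), as in the header itself.
#include <cassert>

int main()
{
    __uint128_t bits = 0;
    for(unsigned i = 0; i < 32; ++i)
    {
        bits <<= 4;
        bits |= (i & 0xFu); // stand-in for the fp4 code of element i
    }
    assert(static_cast<unsigned>(bits & 0xFu) == (31u & 0xFu));  // element 31 -> low nibble
    assert(static_cast<unsigned>((bits >> 124) & 0xFu) == 0u);   // element 0 -> high nibble
    return 0;
}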
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{{x}};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
        value.bitwise, float_values.float2_array, rng, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}

// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l     = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
    uint8_t h     = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0}, tmp_values{0};
    union
    {
        float2_t floatx2_array[16];
        float32_t floatx32_array;
    } float_values{};
    // load the input into the float2 view consumed by the builtin
    float_values.floatx32_array = x;
    // pack two fp32 values per iteration into one fp4x2 byte
    for(int i = 0; i < 16; ++i)
    {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
            tmp_values.bitwise, float_values.floatx2_array[i], rng, scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    }
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0};
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{x};
    // pack element 0 into the most significant nibble down to element 31 in the least significant one
    for(int i = 0; i < 32; ++i)
    {
        auto tmp = utils::sat_convert_to_type_sr<f4_t>(float_values.float_array[i] / scale, rng);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    }
    return f4_values.f4x32_array;
#endif
}
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    // When the exponent bits are not all 1s, then the value is zero, normal,
    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
    // least significant bits of the float mantissa are greater than 0x8000,
    // or if they are equal to 0x8000 and the least significant bit of the
    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
    // has the value 0x7f, then incrementing it causes it to become 0x00 and
    // the exponent is incremented by one, which is the next higher FP value
    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
    // incrementing it causes it to become an exponent of 0xFF and a mantissa
    // of 0x00, which is Inf, the next higher value to the unrounded value.
    bool flag0 = ~u.int32 & 0x7f800000;

    // When all of the exponent bits are 1, the value is Inf or NaN.
    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
    // bit being 1. Signaling NaN is indicated by the most significant
    // mantissa bit being 0 but some other bit(s) being 1. If any of the
    // lower 16 bits of the mantissa are 1, we set the least significant bit
    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
    // the bfloat16's mantissa bits are all 0.
    bool flag1 = !flag0 && (u.int32 & 0xffff);

    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN

    return uint16_t(u.int32 >> 16);
}

// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}

// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float scale               = 1.0f;
    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}

// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};
    value.f4x2_array[0] = x;
    float scale         = 1.0f;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
    float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                       x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
                 utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                       x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
    return ret;
#endif
}
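// Illustrative, standalone reference (assumes the usual e2m1 value set used by MX
// formats; the header's own decode goes through utils::to_float): at unit scale the
// sixteen fp4 codes map to the values below, with the top bit acting as the sign.
#include <cstdio>

int main()
{
    const float decoded[16] = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                               -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
    for(int code = 0; code < 16; ++code)
    {
        std::printf("0x%X -> %g\n", code, decoded[code]);
    }
    return 0;
}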
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    union
    {
        float32_t float32_array;
        float float_array[32];
    } ret{};
    float2_t op;
    float scale = 1.0f;
    // unpack one fp4x2 byte into two fp32 values per iteration
    for(int i = 0; i < 16; ++i)
    {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret.float_array[2 * i]     = op[0];
        ret.float_array[2 * i + 1] = op[1];
    }
    return ret.float32_array;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // unpack one fp4x2 byte into two fp32 values per iteration
    for(int i = 0; i < 16; ++i)
    {
        float_values.float_array[2 * i] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
        float_values.float_array[2 * i + 1] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
    }
    return float_values.float32_array;
#endif
}
/**
 * @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts it
 * to the 6-bit floating-point format (f6_t).
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
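// Illustrative, standalone sketch of why the scale argument exists (independent of
// the exact f6 bit layout): dividing by a per-block scale before quantising, and
// multiplying it back after dequantising, keeps large-magnitude blocks inside the
// narrow format's dynamic range. The 0.5-step grid below is only a toy stand-in.
#include <cmath>
#include <cstdio>

int main()
{
    const float x     = 48.0f; // too large for a small float format at unit scale
    const float scale = 16.0f; // per-block scale chosen so x / scale fits
    const float q     = std::round((x / scale) * 2.0f) / 2.0f; // quantise on a 0.5 grid
    std::printf("quantised: %g, dequantised: %g\n", q, q * scale);
    return 0;
}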
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion with
 * rounding to nearest / even to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // the builtin consumes the 32 floats as two float16_t halves: elements 0..15 and 16..31
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(&x) + 1;
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
 *
 * Divides the input by the specified scale, then performs saturation and conversion
 * to f6_t based on a pseudo-randomly generated seed.
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion with
 * stochastic rounding to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    float x_fp32 = static_cast<float>(x);

    return bf16_convert_rtn<bhalf_t>(x_fp32);
}

/**
 * @brief Specializes the type conversion template for converting a float into the 6-bit float type
 * (f6_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float value to be converted.
 * @return The converted f6_t value.
 */
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}

/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into the
 * vector of 32 6-bit float types (f6x32_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float vector to be converted.
 * @return The converted f6x32_t vector.
 */
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
 * float.
 *
 * Interprets an f6_t value as a float using the default scale factor of 1.
 *
 * @param x The 6-bit float (f6_t) value to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}

/**
 * @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
 * (f6x32_t) to a vector of 32 floats.
 *
 * Interprets the f6_t values as floats using the default scale factor of 1.
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts
 * it to a 6-bit BF6 floating-point format.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
 * round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts
 * it to a 6-bit BF6 floating-point format.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // the builtin consumes the 32 floats as two float16_t halves: elements 0..15 and 16..31
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(&x) + 1;
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
 *
 * Divides the input by the specified scale, and converts the result to a 6-bit BF6
 * floating-point format with stochastic rounding.
 *
 * @param x The float value to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6_t value.
 */
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
 * rounding.
 *
 * Divides the input by the specified scale, and converts the result to a 6-bit BF6
 * floating-point format with stochastic rounding.
 *
 * @param x The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
 * @brief Specializes float-to-bf6_t conversion.
 *
 * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
 * otherwise uses round-to-nearest-even.
 *
 * @param x Input float value to convert.
 * @return Converted bf6_t value.
 */
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}

/**
 * @brief Specializes vector of 32 float-to-bf6_t conversion.
 *
 * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
 * otherwise uses round-to-nearest-even.
 *
 * @param x Input float vector to convert.
 * @return Converted bf6x32_t vector.
 */
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a bf6_t value to float.
 *
 * Interprets the bf6_t value using the default scale factor of 1 and returns
 * its floating-point representation.
 *
 * @param x The bf6_t value to convert.
 * @return The float representation of the given bf6_t value.
 */
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}

/**
 * @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
 * vector of 32 floats.
 *
 * Interprets the bf6x32_t value using the default scale factor of 1 and returns
 * its floating-point representation.
 *
 * @param x The bf6x32_t value to convert.
 * @return The float representation of the given vector.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
                                              const std::array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}
#endif

template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
{
    for(size_t i = 0; i < NumElems; i++)
    {
        y[i] = type_convert<Y>(x[i]);
    }
}

} // namespace ck
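// Usage sketch (illustrative, standalone; assumes the ck headers are on the include
// path and a HIP-capable host compiler, since bhalf_t is defined by this library):
#include <array>
#include "ck/utility/type_convert.hpp"

int main()
{
    std::array<float, 4> src{0.5f, 1.0f, 1.5f, 2.0f};
    std::array<ck::bhalf_t, 4> dst{};
    ck::array_convert(dst, src); // element-wise type_convert<bhalf_t>(src[i])
    return 0;
}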
include/ck_tile/core.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...
...
@@ -31,6 +32,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
...
...
@@ -53,8 +55,8 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/core/algorithm/static_encoding_pattern.hpp
0 → 100644
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
/**
 * @brief Enumeration describing static tile distribution patterns.
 *
 */
enum struct tile_distribution_pattern
{
    /**
     * @brief Thread raked pattern.
     *
     */
    thread_raked,
    /**
     * @brief Warp raked pattern.
     *
     */
    warp_raked,
    /**
     * @brief Block raked pattern - aka linear.
     *
     */
    block_raked,
};

struct TileDistributionEncodingPattern
{
};
/**
 * @brief Class creating 2D static tile distribution with different load/store patterns.
 *
 * @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
 *       is contiguous and we can do vector load on this dimension.
 *
 * @tparam BlockSize Number of threads in a workgroup.
 * @tparam YPerTile The tile size of outer/leftmost dimension.
 * @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
 * @tparam VecSize The vector access size.
 * @tparam DistributionPattern The enumeration describing used access pattern.
 */
template <index_t BlockSize,
          index_t YPerTile,
          index_t XPerTile,
          index_t VecSize,
          tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::thread_raked>
    : public TileDistributionEncodingPattern
{
    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
    // # of rows in Y dim accessed by single wavefront in one iteration
    static constexpr index_t Y1 = warp_size / X0;
    static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
    static constexpr index_t Y0 = num_warps;
    // YPerWarp = YPerTile / Y0;
    // Y2 = YPerWarp / Y1;
    static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
    static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<2, 1>>{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 2>>{});
    }
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::warp_raked>
    : public TileDistributionEncodingPattern
{
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");

    static constexpr index_t Y0 = num_warps;
    static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y0 must cover whole workgroup!");

    static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
                                         YPerTile,
                                         XPerTile,
                                         VecSize,
                                         tile_distribution_pattern::block_raked>
    : public TileDistributionEncodingPattern
{
    // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
    static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");

    static constexpr index_t warp_size = get_warp_size();
    static constexpr index_t num_warps = BlockSize / get_warp_size();

    static constexpr index_t X1 = VecSize;
    static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim

    static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
    static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");

    static constexpr index_t Y1 = num_warps;
    static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");

    static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
    static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");

    CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
                                       tuple<sequence<2>, sequence<2, 1>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 0>>{});
    }
};

} // namespace ck_tile
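A usage sketch of the pattern above (tile sizes are illustrative and assume a wave64 target so that X0 * Y1 covers a full wavefront; not code from this file):

// Sketch only: 256 threads, 64x64 tile, 8-wide vector accesses on the contiguous X dim.
using LoadPattern = ck_tile::TileDistributionEncodingPattern2D<
    256, // BlockSize
    64,  // YPerTile
    64,  // XPerTile
    8,   // VecSize
    ck_tile::tile_distribution_pattern::thread_raked>;

constexpr auto dist = LoadPattern::Make2DStaticTileDistribution();
// `dist` can then be handed to the usual tile-window / load_tile machinery.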
include/ck_tile/core/arch/arch.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -12,18 +12,37 @@
namespace
ck_tile
{
enum
struct
address_space_enum
template
<
typename
,
bool
>
struct
safe_underlying_type
;
template
<
typename
T
>
struct
safe_underlying_type
<
T
,
true
>
{
using
type
=
std
::
underlying_type_t
<
T
>
;
};
template
<
typename
T
>
struct
safe_underlying_type
<
T
,
false
>
{
using
type
=
void
;
};
template
<
typename
T
>
using
safe_underlying_type_t
=
typename
safe_underlying_type
<
T
,
std
::
is_enum
<
T
>::
value
>::
type
;
enum
struct
address_space_enum
:
std
::
uint16_t
{
generic
,
generic
=
0
,
global
,
lds
,
sgpr
,
vgpr
,
constant
,
vgpr
};
enum
struct
memory_operation_enum
enum
struct
memory_operation_enum
:
std
::
uint16_t
{
set
,
set
=
0
,
atomic_add
,
atomic_max
,
add
...
...
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif
}
#define CK_CONSTANT_ADDRESS_SPACE                                                           \
    __attribute__((address_space(                                                           \
        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))

template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
    // only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
    // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
    // only c-style pointer cast seems be able to be compiled;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

} // namespace ck_tile
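A minimal sketch of how the two casts above pair up (the argument struct and kernel are hypothetical, not part of this header):

// Sketch only: kernel arguments live in the constant address space (4) and are
// cast back to the generic address space (0) before being dereferenced.
struct KernelArgs { const float* p_in; float* p_out; };

__global__ void consume(const KernelArgs CK_CONSTANT_ADDRESS_SPACE* p_args_const)
{
    const KernelArgs* args  = ck_tile::cast_pointer_to_generic_address_space(p_args_const);
    args->p_out[threadIdx.x] = args->p_in[threadIdx.x];
}
// The launching side would obtain the constant-space pointer with
// ck_tile::cast_pointer_to_constant_address_space(p_device_args).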
include/ck_tile/core/arch/generic_memory_space_atomic.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
...
...
@@ -8,16 +8,75 @@
namespace
ck_tile
{
CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
    return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}

CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t rtn;
    rtn[0] = add_bf16_t(a[0], b[0]);
    rtn[1] = add_bf16_t(a[1], b[1]);
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    return rtn;
}
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    rtn[2] = add<bf16_t, float>(a[2], b[2]);
    rtn[3] = add<bf16_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    rtn[4] = add<fp8_t, float>(a[4], b[4]);
    rtn[5] = add<fp8_t, float>(a[5], b[5]);
    rtn[6] = add<fp8_t, float>(a[6], b[6]);
    rtn[7] = add<fp8_t, float>(a[7], b[7]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    rtn[4] = add<bf8_t, float>(a[4], b[4]);
    rtn[5] = add<bf8_t, float>(a[5], b[5]);
    rtn[6] = add<bf8_t, float>(a[6], b[6]);
    rtn[7] = add<bf8_t, float>(a[7], b[7]);
    return rtn;
}
...
...
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
}
while
(
cur_v
.
u32
!=
old_v
);
}
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // Union to treat the pointer as either bf16x4_t* or uint64_t*:
    union U64BF164_ADDR
    {
        uint64_t* u64_a;
        bf16x4_t* bf164_a;
    };
    // Union to treat the data as either bf16x4_t or 64-bit integer
    union U64BF164
    {
        uint64_t u64;
        bf16x4_t bf164;
    };

    U64BF164_ADDR addr;
    addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location

    // First read (non-atomic) of the old value
    U64BF164 cur_v;
    cur_v.u64 = *addr.u64_a;

    U64BF164 new_v_union;
    uint64_t old_v, new_v;

    do
    {
        // old 64 bits
        old_v = cur_v.u64;
        // Add elementwise in bf16
        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
        new_v             = new_v_union.u64;
        // Attempt the 64-bit CAS
        cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
template
<
>
CK_TILE_DEVICE
void
atomic_add
<
fp8x4_t
>
(
fp8x4_t
*
p_dst
,
const
fp8x4_t
&
x
)
{
union
U32FP84_ADDR
{
uint32_t
*
u32_a
;
fp8x4_t
*
fp84_a
;
};
union
U32FP84
{
uint32_t
u32
;
fp8x4_t
fp84
;
};
U32FP84_ADDR
dword_addr
;
U32FP84
cur_v
;
U32FP84
new_
;
uint32_t
old_v
,
new_v
;
dword_addr
.
fp84_a
=
p_dst
;
cur_v
.
u32
=
*
dword_addr
.
u32_a
;
do
{
old_v
=
cur_v
.
u32
;
new_
.
fp84
=
add_fp8x4_t
(
cur_v
.
fp84
,
x
);
new_v
=
new_
.
u32
;
cur_v
.
u32
=
atomicCAS
(
dword_addr
.
u32_a
,
old_v
,
new_v
);
}
while
(
cur_v
.
u32
!=
old_v
);
}
template
<
>
CK_TILE_DEVICE
void
atomic_add
<
bf8x4_t
>
(
bf8x4_t
*
p_dst
,
const
bf8x4_t
&
x
)
{
union
U32BF84_ADDR
{
uint32_t
*
u32_a
;
bf8x4_t
*
bf84_a
;
};
union
U32BF84
{
uint32_t
u32
;
bf8x4_t
bf84
;
};
U32BF84_ADDR
dword_addr
;
U32BF84
cur_v
;
U32BF84
new_
;
uint32_t
old_v
,
new_v
;
dword_addr
.
bf84_a
=
p_dst
;
cur_v
.
u32
=
*
dword_addr
.
u32_a
;
do
{
old_v
=
cur_v
.
u32
;
new_
.
bf84
=
add_bf8x4_t
(
cur_v
.
bf84
,
x
);
new_v
=
new_
.
u32
;
cur_v
.
u32
=
atomicCAS
(
dword_addr
.
u32_a
,
old_v
,
new_v
);
}
while
(
cur_v
.
u32
!=
old_v
);
}
//
// Atomic add for fp8x8_t
//
template
<
>
CK_TILE_DEVICE
void
atomic_add
<
fp8x8_t
>
(
fp8x8_t
*
p_dst
,
fp8x8_t
const
&
x
)
{
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
union
U64FP88_ADDR
{
uint64_t
*
u64_a
;
// pointer to 64-bit integer
fp8x8_t
*
fp88_a
;
// pointer to fp8x8_t
};
union
U64FP88
{
uint64_t
u64
;
fp8x8_t
fp88
;
};
U64FP88_ADDR
dword_addr
;
U64FP88
cur_v
;
U64FP88
new_v_union
;
uint64_t
old_v
,
new_v
;
// Point to the destination as both fp8x8_t* and uint64_t*.
dword_addr
.
fp88_a
=
p_dst
;
// Initial read of 64 bits from memory
cur_v
.
u64
=
*
dword_addr
.
u64_a
;
do
{
old_v
=
cur_v
.
u64
;
// Add each fp8 element using your add_fp8x8_t(...) routine
new_v_union
.
fp88
=
add_fp8x8_t
(
cur_v
.
fp88
,
x
);
new_v
=
new_v_union
.
u64
;
// Attempt 64-bit CAS
cur_v
.
u64
=
atomicCAS
(
dword_addr
.
u64_a
,
old_v
,
new_v
);
}
while
(
cur_v
.
u64
!=
old_v
);
}
//
// Atomic add for bf8x8_t
//
template
<
>
CK_TILE_DEVICE
void
atomic_add
<
bf8x8_t
>
(
bf8x8_t
*
p_dst
,
bf8x8_t
const
&
x
)
{
union
U64BF88_ADDR
{
uint64_t
*
u64_a
;
bf8x8_t
*
bf88_a
;
};
union
U64BF88
{
uint64_t
u64
;
bf8x8_t
bf88
;
};
U64BF88_ADDR
dword_addr
;
U64BF88
cur_v
;
U64BF88
new_v_union
;
uint64_t
old_v
,
new_v
;
dword_addr
.
bf88_a
=
p_dst
;
// Read the original 64 bits
cur_v
.
u64
=
*
dword_addr
.
u64_a
;
do
{
old_v
=
cur_v
.
u64
;
// Add each bf8 element using your add_bf8x8_t(...) routine
new_v_union
.
bf88
=
add_bf8x8_t
(
cur_v
.
bf88
,
x
);
new_v
=
new_v_union
.
u64
;
// 64-bit CAS loop
cur_v
.
u64
=
atomicCAS
(
dword_addr
.
u64_a
,
old_v
,
new_v
);
}
while
(
cur_v
.
u64
!=
old_v
);
}
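Each of the specializations above repeats the same read-modify-CAS retry protocol; a generic sketch of that pattern (hypothetical helper, not part of this header) is:

// Sketch only: WordT must have the same size as VecT (uint32_t for fp8x4_t/bf8x4_t,
// uint64_t for bf16x4_t/fp8x8_t/bf8x8_t) and *p_dst must be WordT-aligned.
template <typename VecT, typename WordT, typename AddFn>
CK_TILE_DEVICE void atomic_add_via_cas(VecT* p_dst, const VecT& x, AddFn add_fn)
{
    static_assert(sizeof(VecT) == sizeof(WordT), "pack the whole vector into one CAS word");
    WordT* p_word = reinterpret_cast<WordT*>(p_dst);
    WordT cur     = *p_word; // non-atomic snapshot of the current value
    WordT old_v, new_v;
    do
    {
        old_v    = cur;
        VecT sum = add_fn(*reinterpret_cast<const VecT*>(&old_v), x); // element-wise add
        new_v    = *reinterpret_cast<const WordT*>(&sum);
        cur      = atomicCAS(p_word, old_v, new_v); // retry if another thread won the race
    } while(cur != old_v);
}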
template
<
typename
T
,
index_t
N
>
CK_TILE_DEVICE
void
atomic_add_g
(
T
*
p_dst
,
const
thread_buffer
<
T
,
N
>&
x
)
{
...
...
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
                  "wrong! not implemented");
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
                  "The granularity of the thread buffer is unsupported on the hardware!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};
...
...
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
        }
        else if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1, x.template get_as<bf16x2_t>()[I1]);
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1, x.template get_as<bf16x4_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, fp8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
        }
        if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
        }
        if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, bf8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
        }
        if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
        }
        if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
        }
    }
...
...
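For completeness, a hedged usage sketch of atomic_add_g (how the per-thread buffer gets filled is outside this file):

// Sketch only: assumes `tb` holds per-thread partial results produced elsewhere.
// ck_tile::thread_buffer<ck_tile::fp8_t, 8> tb = /* ... */;
// ck_tile::atomic_add_g(p_global_fp8, tb); // dispatches to the fp8x8_t CAS path above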
include/ck_tile/core/config.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...
...
@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...
...
@@ -230,3 +234,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
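As a small illustration of how this switch is meant to be consumed downstream (the constant name here is made up; the 448/240 bounds come from the e4m3 comparison in float8.hpp):

// Sketch only: the largest finite e4m3 magnitude differs between the two encodings.
#if CK_TILE_USE_OCP_FP8
inline constexpr float kFp8E4M3Max = 448.0f; // OCP e4m3
#else
inline constexpr float kFp8E4M3Max = 240.0f; // FNUZ e4m3
#endif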
include/ck_tile/core/container/tuple.hpp
View file @
7572a691
...
...
@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
            using Idx = number<tuple<Ts...>::size() - i - 1>;
            return t.at(Idx{});
        },
        number<tuple<Ts...>::size()()>{});
        number<tuple<Ts...>::size()>{});
}
// Reduce tuple values in specific range using Function
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
...
@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<bfloat16_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 7;
};

#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
...
...
include/ck_tile/core/numeric/float8.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
...
@@ -14,6 +14,12 @@
#pragma once
#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
namespace
ck_tile
{
// fp8 rounding modes
...
...
@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
stochastic
};
/**
 * \brief FP8 interpretation used in conversion algorithms
 */
enum class fp8_interpretation
{
    E4M3_OCP  = 0, // OCP FP8 E4M3
    E5M2_OCP  = 1, // OCP BF8 E5M2
    E4M3_FNUZ = 2, // FNUZ FP8 E4M3
    E5M2_FNUZ = 3, // FNUZ BF8 E5M2
};
/*
 * ______________ NANOO _________________ | ______________ IEEE ________________
 * ______________ FNUZ __________________ | ______________ OCP _________________
 *        e4m3               e5m2         |      e4m3               e5m2
 * bias :   8                 16          |       7                  15
 * inf  : 1.0000.000       1.00000.00     |      N/A             s.11111.00
 * Nan  : 1.0000.000       1.00000.00     |  s.1111.111     s.11111.{01, 10, 11}
 * zero : 0.0000.000       0.00000.00     |  s.0000.000          s.00000.00
 * Max(norm) : s.1111.111 (240)  s.11111.11(57344) | s.1111.110(448)  s.11110.11(57344)
 * Max(snorm): s.0000.111        s.00000.11        | s.0000.111(448)  s.00000.11(57344)
 * Max(snorm): s.0000.111        s.00000.11        | s.0000.111       s.00000.11
 *             0.0068359375      2.288818e-05      | 0.013671875      4.57763671875e-05
 * Min(norm) : s.0001.000        s.00001.00        | s.0001.000       s.00001.00
 *             2^-7(0.00078125)  2^-15(3.05176e-05)| 2^-6(0.015625)   2^-14(6.10352e-05)
...
...
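A quick sanity check of the Max(norm) row above: OCP e4m3 tops out at s.1111.110, i.e. 1.75 x 2^(15-7) = 448, while FNUZ e4m3 tops out at s.1111.111, i.e. 1.875 x 2^(15-8) = 240; for e5m2 both encodings give 1.75 x 2^15 = 57344.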
@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{
    static constexpr int exponent = 4;
    static constexpr int mantissa = 3;
#if defined(__gfx94__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 7; // OCP
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
    static constexpr int bias = 8; // FNUZ
#endif
    using raw_type = uint8_t;
    raw_type data;
...
...
@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{
    static constexpr int exponent = 5;
    static constexpr int mantissa = 2;
#if defined(__gfx94__)
    static constexpr int bias = 1 << (exponent - 1); // NANOO
#if CK_TILE_USE_OCP_FP8
    static constexpr int bias = 15; // OCP
#else
    static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
    static constexpr int bias = 16; // FNUZ
#endif
    using raw_type = uint8_t;
    raw_type data;
...
...
@@ -183,501 +200,727 @@ struct native_t<bf8_t>
};
#else
using
fp8_t
=
_BitInt
(
8
);
using
fp8_raw_t
=
uint8_t
;
using
bf8_t
=
unsigned
_BitInt
(
8
);
using
bf8_raw_t
=
uint8_t
;
#endif
// below is sw fp8 conversion, not utilizing hw instruction
namespace
impl
{
template
<
typename
T
>
struct
numeric_traits
;
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
CK_TILE_HOST_DEVICE
Y
run_cast_to_f8
(
X
x
,
uint32_t
rng
)
template
<
>
struct
numeric_traits
<
fp8_t
>
{
// fp8/bf8 exponent/mantissa layout
constexpr
int
out_exp
=
numeric_traits
<
Y
>::
exp
;
constexpr
int
out_mant
=
numeric_traits
<
Y
>::
mant
;
using
bitwise_type
=
fp8_raw_t
;
// original type exponent/mantissa layout
constexpr
int
in_exp
=
numeric_traits
<
X
>::
exp
;
constexpr
int
in_mant
=
numeric_traits
<
X
>::
mant
;
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
#if CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
7
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E4M3_OCP
;
#else
static
constexpr
int
bias
=
8
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E4M3_FNUZ
;
#endif
static
constexpr
uint8_t
abs_mask
=
0x7F
;
};
int
exponent
,
bias
;
uint32_t
head
,
mantissa
,
sign
;
// nan code is same for float and half
#if CK_TILE_USE_CUSTOM_DATA_TYPE
constexpr
Y
nan_code
=
numeric
<
Y
>::
quiet_NaN
();
// __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
template
<
>
struct
numeric_traits
<
bf8_t
>
{
using
bitwise_type
=
bf8_raw_t
;
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
#if CK_TILE_USE_OCP_FP8
static
constexpr
int
bias
=
15
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E5M2_OCP
;
#else
constexpr
Y
nan_code
=
0x80
;
static
constexpr
int
bias
=
16
;
static
constexpr
fp8_interpretation
f8_interpret
=
fp8_interpretation
::
E5M2_FNUZ
;
#endif
static
constexpr
uint8_t
abs_mask
=
0x7F
;
};
constexpr
uint32_t
nan_mask
=
numeric_traits
<
X
>::
nan_mask
;
// below is sw fp8 conversion, not utilizing hw instruction
namespace
impl
{
template
<
typename
SrcT
,
typename
DstT
,
bool
clip
=
true
,
bool
stoch
=
false
>
CK_TILE_HOST_DEVICE
DstT
run_cast_to_f8
(
SrcT
src
,
unsigned
int
rng
=
0
)
{
static_assert
(
std
::
is_same
<
DstT
,
fp8_t
>::
value
||
std
::
is_same
<
DstT
,
bf8_t
>::
value
,
"DstT type must be fp8 or bf8."
);
// convert to bitwise
using
T_bitwise
=
typename
numeric_traits
<
X
>::
bitwise_typ
e
;
T_bitwise
x_bitwise
=
*
(
reinterpret_cast
<
T_bitwise
*>
(
&
x
)
);
constexpr
bool
is_half
=
std
::
is_same
<
SrcT
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
SrcT
,
float
>::
valu
e
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
//
unpack the input, depends on datatype
head
=
x_bitwise
&
numeric_traits
<
X
>::
head_mask
;
mantissa
=
x_bitwise
&
numeric_traits
<
X
>::
mant
_mask
;
exponent
=
(
head
>>
in_mant
)
&
numeric_traits
<
X
>::
exp_mask
;
sign
=
head
>>
(
in_exp
+
in_mant
);
bias
=
numeric_traits
<
X
>::
bias
;
//
fp8/bf8 type exponent/mantissa layout
constexpr
int
DstT_exp
=
numeric_traits
<
DstT
>::
exp
;
// exponent width of the destination type
constexpr
int
DstT_mant
=
numeric_traits
<
DstT
>::
mant
;
// mantissa width of the destination type
constexpr
bool
is_fnuz
=
(
numeric_traits
<
DstT
>::
f8_interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
(
numeric_traits
<
DstT
>::
f8_interpret
==
fp8_interpretation
::
E5M2_FNUZ
)
;
uint32_t
signed_inf
=
(
sign
<<
(
in_exp
+
in_mant
))
+
(((
1
<<
in_exp
)
-
1
)
<<
in_mant
);
uint32_t
drop_mask
=
(
1
<<
(
in_mant
-
out_mant
))
-
1
;
constexpr
int
max_exp
=
(
1
<<
out_exp
)
-
(
negative_zero_nan
?
1
:
2
);
constexpr
int
SrcT_exp
=
numeric_traits
<
SrcT
>::
exp
;
constexpr
int
SrcT_mant
=
numeric_traits
<
SrcT
>::
mant
;
if
constexpr
(
negative_zero_nan
)
using
SrcT_bitwise
=
typename
numeric_traits
<
SrcT
>::
bitwise_type
;
SrcT_bitwise
src_bitwise
=
bit_cast
<
SrcT_bitwise
>
(
src
);
unsigned
long
long
head
,
mantissa
;
int
exponent
,
bias
;
unsigned
int
sign
;
unsigned
long
long
fInf
,
abs_mask
;
head
=
src_bitwise
&
numeric_traits
<
SrcT
>::
head_mask
;
mantissa
=
src_bitwise
&
numeric_traits
<
SrcT
>::
mant_mask
;
exponent
=
(
head
>>
SrcT_mant
)
&
numeric_traits
<
SrcT
>::
exp_mask
;
sign
=
head
>>
(
SrcT_exp
+
SrcT_mant
);
bias
=
numeric_traits
<
SrcT
>::
bias
;
fInf
=
numeric_traits
<
SrcT
>::
Inf
;
abs_mask
=
numeric_traits
<
SrcT
>::
abs_mask
;
unsigned
int
signed_inf
=
0
;
unsigned
int
nan
=
0
;
if
constexpr
(
is_fnuz
)
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
nan_code
;
signed_inf
=
clip
?
((
sign
<<
7
)
+
0x7f
)
:
0x80
;
nan
=
0x80
;
}
else
{
if
((
x_bitwise
&
nan_mask
)
==
nan_mask
)
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
if
constexpr
(
DstT_exp
==
4
)
{
// e4m3
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7e
:
0x7f
);
}
else
{
// e5m2
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7b
:
0x7c
);
}
nan
=
(
sign
<<
7
)
+
0x7f
;
}
// Max values
unsigned
long
long
ifmax
=
0
;
if
constexpr
(
is_float
)
{
if
constexpr
(
DstT_exp
==
5
)
{
ifmax
=
0x47600000
;
}
else
{
if
constexpr
(
is_fnuz
)
{
ifmax
=
0x43700000
;
}
else
{
ifmax
=
0x43E00000
;
}
}
}
else
if
constexpr
(
is_half
)
{
if
constexpr
(
DstT_exp
==
5
)
{
ifmax
=
0x7B00
;
}
else
{
if
constexpr
(
is_fnuz
)
{
ifmax
=
0x5B80
;
}
else
{
ifmax
=
0x5F00
;
}
}
}
// check if x is 0.0
if
(
x_bitwise
==
0
)
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
(
0
));
// Deal with inf and NaNs
if
((
src_bitwise
&
fInf
)
==
fInf
)
{
if
constexpr
(
is_fnuz
)
return
signed_inf
;
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
return
mantissa
!=
0
?
nan
:
signed_inf
;
}
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const
int
out_bias
=
(
1
<<
(
out_exp
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
out_denormal_act_exponent
=
1
-
out_bias
;
// actual exponent of f8 denormal
if
((
src_bitwise
&
abs_mask
)
>
ifmax
)
{
return
signed_inf
;
}
if
(
src_bitwise
==
0
)
{
return
0
;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
DstT_exp
-
1
))
-
1
+
(
is_fnuz
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
//
out
_exponent is the converted f8 exponent with bias encoding
//
f8
_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
out
_exponent
,
exponent_diff
;
int
act_exponent
,
f8
_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
out
_denormal_act_exponent
-
exponent_diff
=
f8
_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
out
_denormal_act_exponent
)
if
(
act_exponent
<=
f8
_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implic
i
t 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
out
_denormal_act_exponent
-
act_exponent
;
exponent_diff
=
f8
_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just
that it does not need shift mantissa
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference
// for this case, act_exponent could be larger. Just
//
that it does not need shift mantissa
}
mantissa
+=
(
1
<<
in
_mant
);
// Add the implicit 1 into mantissa
mantissa
+=
(
1
ull
<<
SrcT
_mant
);
// Add the implicit 1 into mantissa
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
in_mant
-
out_mant
+
exponent_diff
))
-
1
))
==
(
1
<<
(
in_mant
-
out_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
bool
midpoint
=
(
mantissa
&
((
1ull
<<
(
SrcT_mant
-
DstT_mant
+
exponent_diff
))
-
1
))
==
(
1ull
<<
(
SrcT_mant
-
DstT_mant
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
in_mant
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
out_bias
-
(
implicit_one
?
0
:
1
);
bool
implicit_one
=
mantissa
&
(
1ull
<<
SrcT_mant
);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
unsigned
long
long
drop_mask
=
(
1ull
<<
(
SrcT_mant
-
DstT_mant
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
in_mant
-
out_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
mantissa
&
(
1ull
<<
(
SrcT_mant
-
DstT_mant
));
// if the least significant bit that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1ull
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
out
_exponent
==
0
)
if
(
f8
_exponent
==
0
)
{
if
((
1
<<
in
_mant
)
&
mantissa
)
if
((
1
ull
<<
SrcT
_mant
)
&
mantissa
)
{
out_exponent
=
1
;
// denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
}
else
{
if
((
1
<<
(
in
_mant
+
1
))
&
mantissa
)
if
((
1
ull
<<
(
SrcT
_mant
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
// No need to make 1 implicit now as it will be addressed later
f8_exponent
++
;
}
}
mantissa
>>=
(
in
_mant
-
out
_mant
);
mantissa
>>=
(
SrcT
_mant
-
DstT
_mant
);
if
(
out_exponent
>
max_exp
)
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
DstT_exp
)
-
1
;
if
(
f8_exponent
>
max_exp
)
{
if
(
clip
)
if
constexpr
(
clip
)
{
mantissa
=
(
1
<<
out
_mant
)
-
1
;
out
_exponent
=
max_exp
;
mantissa
=
(
1
<<
DstT
_mant
)
-
1
;
f8
_exponent
=
max_exp
;
}
else
{
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
(
signed_inf
))
;
return
signed_inf
;
}
}
// check if x is 0.0 or -0.0
if
(
out_exponent
==
0
&&
mantissa
==
0
)
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
(
negative_zero_nan
?
0
:
(
sign
<<
(
out_exp
+
out_mant
))));
mantissa
&=
(
1
<<
out_mant
)
-
1
;
return
__builtin_bit_cast
(
Y
,
static_cast
<
uint8_t
>
((
sign
<<
(
out_exp
+
out_mant
))
|
(
out_exponent
<<
out_mant
)
|
mantissa
));
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
return
is_fnuz
?
0
:
(
sign
<<
7
);
mantissa
&=
(
1
<<
DstT_mant
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
DstT_mant
)
|
mantissa
;
}
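As a hand-worked example of the path above (not code from the library): converting 0.15625f = 1.25 x 2^-3 to e4m3 needs no rounding, since the 0.25 fraction fits in 3 mantissa bits; with the OCP bias of 7 the stored exponent is -3 + 7 = 4 and the raw byte is 0b0'0100'010 = 0x22, while with the FNUZ bias of 8 it becomes 0b0'0101'010 = 0x2A.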
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
CK_TILE_HOST_DEVICE
Y
run_cast_from_f8
(
X
x
)
template
<
typename
SrcT
,
typename
DstT
,
bool
clip
=
true
>
CK_TILE_HOST_DEVICE
DstT
run_cast_from_f8
(
SrcT
x
)
{
// fp8/bf8 exponent/mantissa layout
constexpr
int
in_exp
=
numeric_traits
<
X
>::
exp
;
constexpr
int
in_mant
=
numeric_traits
<
X
>::
mant
;
// resulting type exponent/mantissa layout
constexpr
int
out_exp
=
numeric_traits
<
Y
>::
exp
;
constexpr
int
out_mant
=
numeric_traits
<
Y
>::
mant
;
uint8_t
x_raw
=
__builtin_bit_cast
(
uint8_t
,
x
);
// prepare the codes
constexpr
uint8_t
nan_code
=
0x80
;
Y
Inf
,
NegInf
,
NaN
,
Neg0
;
using
T_bitwise
=
typename
numeric_traits
<
Y
>::
bitwise_type
;
constexpr
T_bitwise
Inf_bitwise
=
numeric_traits
<
Y
>::
Inf
;
constexpr
T_bitwise
NegInf_bitwise
=
numeric_traits
<
Y
>::
NegInf
;
constexpr
T_bitwise
NaN_bitwise
=
numeric_traits
<
Y
>::
NaN
;
constexpr
T_bitwise
Neg0_bitwise
=
numeric_traits
<
Y
>::
Neg0
;
Inf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Inf_bitwise
));
NegInf
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NegInf_bitwise
));
NaN
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
NaN_bitwise
));
Neg0
=
*
(
reinterpret_cast
<
const
Y
*>
(
&
Neg0_bitwise
));
// check if x is 0.0
if
(
x_raw
==
0
)
return
static_cast
<
Y
>
(
0
);
// unpack the input
uint32_t
sign
=
x_raw
>>
(
in_exp
+
in_mant
);
uint32_t
mantissa
=
x_raw
&
((
1
<<
in_mant
)
-
1
);
int
exponent
=
(
x_raw
&
0x7F
)
>>
in_mant
;
static_assert
(
std
::
is_same
<
SrcT
,
fp8_t
>::
value
||
std
::
is_same
<
SrcT
,
bf8_t
>::
value
,
"SrcT type must be fp8 or bf8."
);
constexpr
int
SrcT_exp
=
numeric_traits
<
SrcT
>::
exp
;
constexpr
int
SrcT_mant
=
numeric_traits
<
SrcT
>::
mant
;
constexpr
bool
is_fnuz
=
(
numeric_traits
<
SrcT
>::
f8_interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
(
numeric_traits
<
SrcT
>::
f8_interpret
==
fp8_interpretation
::
E5M2_FNUZ
);
constexpr
bool
is_half
=
std
::
is_same
<
DstT
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
DstT
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"DstT type must be half_t or float."
);
// destination type exponent/mantissa layout
constexpr
int
DstT_exp
=
numeric_traits
<
DstT
>::
exp
;
// exponent width of the destination type
constexpr
int
DstT_mant
=
numeric_traits
<
DstT
>::
mant
;
// mantissa width of the destination type
constexpr
DstT
fInf
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
Inf
);
constexpr
DstT
fNegInf
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
NegInf
);
constexpr
DstT
fNaN
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
NaN
);
constexpr
DstT
fNeg0
=
bit_cast
<
DstT
>
(
numeric_traits
<
DstT
>::
Neg0
);
DstT
fmax
{
0
},
fmin
{
0
};
// Max number in e5m2 57344
if
constexpr
(
is_half
)
{
fmax
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0x7B00
));
fmin
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0xFB00
));
}
else
if
constexpr
(
is_float
)
{
fmax
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0x47600000
));
fmin
=
bit_cast
<
DstT
>
(
static_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
0xC7600000
));
}
constexpr
int
exp_low_cutoff
=
(
1
<<
(
out_exp
-
1
))
-
(
1
<<
(
in_exp
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
T_bitwise
retval
;
if
(
x
==
0
)
{
return
0
;
}
if
constexpr
(
negative_zero_nan
)
unsigned
long
long
sign
=
x
>>
7
;
unsigned
long
long
mantissa
=
x
&
((
1
<<
SrcT_mant
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
SrcT_mant
;
if
constexpr
(
is_fnuz
)
{
if
(
x
==
0x80
)
{
if
(
x_raw
==
nan_code
)
return
NaN
;
return
fNaN
;
}
}
else
{
if
(
x_raw
==
nan_code
)
return
Neg0
;
if
(
exponent
==
((
1
<<
in_exp
)
-
1
))
return
(
mantissa
==
0
)
?
(
sign
?
NegInf
:
Inf
)
:
NaN
;
if
(
x
==
0x80
)
{
return
fNeg0
;
}
if
constexpr
(
SrcT_exp
==
4
)
{
// e4m3
if
((
x
&
0x7F
)
==
0x7F
)
{
return
fNaN
;
}
}
else
if
((
x
&
0x7C
)
==
0x7C
)
{
// e5m2
if
((
x
&
0x3
)
==
0
)
{
if
constexpr
(
clip
)
{
return
sign
?
fmin
:
fmax
;
}
return
sign
?
fNegInf
:
fInf
;
}
return
fNaN
;
}
}
typename
numeric_traits
<
DstT
>::
bitwise_type
retval
;
if
((
numeric_traits
<
Y
>::
mant
==
10
)
&&
(
numeric_traits
<
X
>::
mant
==
2
)
&&
!
negative_zero_nan
)
if
constexpr
(
SrcT_exp
==
5
&&
is_half
&&
!
is_fnuz
)
{
retval
=
x_raw
;
retval
<<=
8
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
retval
=
x
<<
8
;
return
bit_cast
<
DstT
>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
DstT_exp
-
1
))
-
(
1
<<
(
SrcT_exp
-
1
))
+
1
-
(
is_fnuz
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
in_mant
);
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
SrcT_mant
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
in
_mant
)
-
1
);
mantissa
&=
((
1
ull
<<
SrcT
_mant
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
out
_mant
-
in
_mant
;
mantissa
<<=
DstT
_mant
-
SrcT
_mant
;
// subnormal output (occurs when
T=
half, we=5,
negative_zero_nan
=true)
// subnormal output (occurs when
DstT is
half
_t
, we=5,
is_fnuz
=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
out
_mant
;
mantissa
|=
1
<<
DstT
_mant
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
retval
=
(
sign
<<
(
out_exp
+
out_mant
))
|
(
exponent
<<
out_mant
)
|
mantissa
;
return
*
(
reinterpret_cast
<
const
Y
*>
(
&
retval
));
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
CK_TILE_HOST_DEVICE
Y
cast_to_f8
(
X
x
,
uint32_t
rng
)
{
// check datatypes
constexpr
bool
is_half
=
std
::
is_same
<
X
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
X
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted."
);
retval
=
(
sign
<<
(
DstT_exp
+
DstT_mant
))
|
(
exponent
<<
DstT_mant
)
|
mantissa
;
return
run
_cast
_to_f8
<
X
,
Y
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
return
bit
_cast
<
DstT
>
(
retval
);
}
template
<
typename
X
,
typename
Y
,
bool
negative_zero_nan
>
CK_TILE_HOST_DEVICE
Y
cast_
from
_f8
(
X
x
)
template
<
typename
X
,
typename
Y
,
bool
clip
,
bool
stoch
>
CK_TILE_HOST_DEVICE
Y
cast_
to
_f8
(
X
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
Y
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
Y
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported."
);
return
run_cast_from_f8
<
X
,
Y
,
negative_zero_nan
>
(
x
);
return
bit_cast
<
Y
>
(
run_cast_to_f8
<
X
,
Y
,
clip
,
stoch
>
(
x
,
rng
));
}
}
// namespace impl
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_sr_raw
(
float
x
)
#if CK_TILE_FP8_CVT_DEVICE
/**
* @brief Cast float to fp8/bf8 using device conversion instructions
*/
template
<
fp8_interpretation
interpret
,
bool
saturate
,
bool
stochastic_rounding
=
false
>
CK_TILE_DEVICE
uint8_t
cast_to_f8_from_f32
(
float
v
,
unsigned
int
rng
=
0
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
uint8_t
i8data
;
union
{
float
fval
;
u
int32_
t
i32val
;
u
int8_t
i8val
[
4
];
// not endian independent
u
nsigned
in
t
i32val
;
u
nsigned
char
i8val
[
4
];
//
NOTE:
not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
unsigned
int
ival
=
0
;
val
.
fval
=
v
;
if
constexpr
(
saturate
)
{
if
constexpr
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
}
else
if
constexpr
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
{
// OCP type
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
448.0
,
-
448.0
);
}
}
else
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
57344.0
,
-
57344.0
);
}
}
}
if
constexpr
(
stochastic_rounding
)
{
ival
=
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
?
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
)
:
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
stochastic
;
return
bit_cast
<
fp8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
fp8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
#endif
i8data
=
val
.
i8val
[
0
];
// little endian
}
else
{
// RNE CVT
ival
=
(
interpret
==
fp8_interpretation
::
E4M3_FNUZ
)
||
(
interpret
==
fp8_interpretation
::
E4M3_OCP
)
?
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
)
:
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
}
return
i8data
;
}
#endif // CK_TILE_FP8_CVT_DEVICE
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_sr_raw
(
float
x
)
}
// namespace impl
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
* rounding.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may
* involve clipping and uses a pseudo-random number generator for the stochastic rounding.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template
<
typename
SrcT
,
typename
DstT
>
CK_TILE_HOST_DEVICE
typename
numeric_traits
<
DstT
>::
bitwise_type
float_to_fp8_sr_raw
(
SrcT
x
)
{
constexpr
bool
clip
=
true
;
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator_t
<
float
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
// little endian
uint32_t
rng
=
prand_generator_t
<
SrcT
,
seed
>
{}(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
#if CK_TILE_FP8_CVT_DEVICE
return
impl
::
cast_to_f8_from_f32
<
numeric_traits
<
DstT
>::
f8_interpret
,
clip
,
true
>
(
x
,
rng
);
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
stochastic
;
return
bit_cast
<
bf8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
return
bit_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
impl
::
cast_to_f8
<
SrcT
,
DstT
,
clip
,
true
>
(
x
,
rng
));
#endif
}
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_rtn_raw
(
float
x
)
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
* nearest even.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may involve clipping.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template
<
typename
SrcT
,
typename
DstT
>
CK_TILE_HOST_DEVICE
typename
numeric_traits
<
DstT
>::
bitwise_type
float_to_fp8_rtn_raw
(
SrcT
x
)
{
#if defined(__gfx94__)
float
max_fp8
=
240.0
f
;
x
=
x
>
max_fp8
?
max_fp8
:
(
x
<
-
max_fp8
?
-
max_fp8
:
x
);
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
bit_cast
<
fp8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
fp8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
#endif
}
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_rtn_raw
(
float
x
)
{
#if defined(__gfx94__)
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// not endian independent
}
val
;
val
.
fval
=
x
;
uint32_t
ival
=
0
;
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
return
val
.
i8val
[
0
];
#if CK_TILE_FP8_CVT_DEVICE
return
impl
::
cast_to_f8_from_f32
<
numeric_traits
<
DstT
>::
f8_interpret
,
clip
,
false
>
(
x
,
0
);
#else
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
clip
=
true
;
constexpr
fp8_rounding_mode
rm
=
fp8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
return
bit_cast
<
bf8_raw_t
>
(
impl
::
cast_to_f8
<
float
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
fp8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
return
bit_cast
<
typename
numeric_traits
<
DstT
>::
bitwise_type
>
(
impl
::
cast_to_f8
<
SrcT
,
DstT
,
clip
,
false
>
(
x
,
0
));
#endif
}
// clang-format off
template
<
fp8_rounding_mode
rounding
>
template
<
fp8_rounding_mode
rounding
>
CK_TILE_HOST_DEVICE
fp8_raw_t
float_to_fp8_raw
(
float
x
,
constant
<
rounding
>
)
{
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
return
float_to_fp8_rtn_raw
(
x
);
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
return
float_to_fp8_sr_raw
(
x
);
else
return
fp8_raw_t
{
0
};
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
{
return
float_to_fp8_rtn_raw
<
float
,
fp8_t
>
(
x
);
}
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
{
return
float_to_fp8_sr_raw
<
float
,
fp8_t
>
(
x
);
}
else
{
return
fp8_raw_t
{
0
};
}
}
template
<
fp8_rounding_mode
rounding
>
template
<
fp8_rounding_mode
rounding
>
CK_TILE_HOST_DEVICE
bf8_raw_t
float_to_bf8_raw
(
float
x
,
constant
<
rounding
>
)
{
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
return
float_to_bf8_rtn_raw
(
x
);
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
return
float_to_bf8_sr_raw
(
x
);
else
return
bf8_raw_t
{
0
};
if
constexpr
(
rounding
==
fp8_rounding_mode
::
standard
)
{
return
float_to_fp8_rtn_raw
<
float
,
bf8_t
>
(
x
);
}
else
if
constexpr
(
rounding
==
fp8_rounding_mode
::
stochastic
)
{
return
float_to_fp8_sr_raw
<
float
,
bf8_t
>
(
x
);
}
else
{
return
bf8_raw_t
{
0
};
}
}
CK_TILE_HOST_DEVICE
float
fp8_to_float_raw
(
fp8_raw_t
x
)
{
#if
defined(__gfx94__)
#if
CK_TILE_FP8_CVT_DEVICE
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_fp8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
return
impl
::
cast_from_f8
<
fp8_t
,
float
,
negative_zero_nan
>
(
bit_cast
<
fp8_t
>
(
x
));
return
impl
::
run_cast_from_f8
<
fp8_t
,
float
>
(
bit_cast
<
fp8_t
>
(
x
));
#endif
}
CK_TILE_HOST_DEVICE
float
bf8_to_float_raw
(
bf8_raw_t
x
)
{
#if
defined(__gfx94__)
#if
CK_TILE_FP8_CVT_DEVICE
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
x
);
fval
=
__builtin_amdgcn_cvt_f32_bf8
(
i32val
,
0
);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return
fval
;
#else
constexpr
bool
negative_zero_nan
=
true
;
return
impl
::
cast_from_f8
<
bf8_t
,
float
,
negative_zero_nan
>
(
bit_cast
<
bf8_t
>
(
x
));
return
impl
::
run_cast_from_f8
<
bf8_t
,
float
>
(
bit_cast
<
bf8_t
>
(
x
));
#endif
}
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
CK_TILE_HOST_DEVICE
fp8_t
float_to_fp8
(
float
x
,
constant
<
rounding
>
=
{})
{
return
bit_cast
<
fp8_t
>
(
float_to_fp8_raw
(
x
,
constant
<
rounding
>
{}));
}
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
template
<
fp8_rounding_mode
rounding
=
static_cast
<
fp8_rounding_mode
>(
CK_TILE_FLOAT_TO_FP8_DEFAULT
)
>
CK_TILE_HOST_DEVICE
bf8_t
float_to_bf8
(
float
x
,
constant
<
rounding
>
=
{})
{
return
bit_cast
<
bf8_t
>
(
float_to_bf8_raw
(
x
,
constant
<
rounding
>
{}));
}
CK_TILE_HOST_DEVICE
float
fp8_to_float
(
fp8_t
x
)
{
return
fp8_to_float_raw
(
bit_cast
<
fp8_raw_t
>
(
x
));
}
CK_TILE_HOST_DEVICE
float
fp8_to_float
(
fp8_t
x
)
{
return
fp8_to_float_raw
(
bit_cast
<
fp8_raw_t
>
(
x
));
}
CK_TILE_HOST_DEVICE
float
bf8_to_float
(
bf8_t
x
)
{
return
bf8_to_float_raw
(
bit_cast
<
bf8_raw_t
>
(
x
));
}
CK_TILE_HOST_DEVICE
float
bf8_to_float
(
bf8_t
x
)
{
return
bf8_to_float_raw
(
bit_cast
<
bf8_raw_t
>
(
x
));
}
// clang-format on
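A hedged usage sketch of the public helpers defined above (the input value is arbitrary; the quantized result depends on the active fp8 encoding):

// Sketch only: round-trip a float through fp8 with both rounding modes.
float x = 3.14159f;
ck_tile::fp8_t a = ck_tile::float_to_fp8(x); // default rounding (CK_TILE_FLOAT_TO_FP8_DEFAULT)
ck_tile::fp8_t b = ck_tile::float_to_fp8(
    x, ck_tile::constant<ck_tile::fp8_rounding_mode::stochastic>{}); // stochastic rounding
float xa = ck_tile::fp8_to_float(a); // nearest representable e4m3 value, e.g. 3.25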
template
<
typename
T
>
struct
numeric_traits
;
template
<
class
T
>
struct
numeric
;
#if CK_TILE_USE_OCP_FP8
template
<
>
struct
numeric
_traits
<
fp8_t
>
struct
numeric
<
fp8_t
>
{
static
constexpr
int
exp
=
4
;
static
constexpr
int
mant
=
3
;
#if defined(__gfx94__)
static
constexpr
int
bias
=
8
;
#else
static
constexpr
int
bias
=
7
;
#endif
// minimum finite value, or minimum positive normal value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
min
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x08
));
// 0b00001000 = 2^-6
}
// minumum finite value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
lowest
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0xfe
));
// 0b11111110 = -448
}
// maximum finite value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
max
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x7e
));
// 0b01111110 = 448
}
// difference between 1.0 and next representable f8 value (1.125)
// returns fp8_t(0.125)
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
epsilon
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x20
));
// 0.125
}
// rounding error (0.0625)
// half of epsilon
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
round_error
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x18
));
// 0.0625
}
// quiet NaN
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
quiet_NaN
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x7F
));
// 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
signaling_NaN
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0xFF
));
// 0b11111111
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
denorm_min
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0x01
));
}
CK_TILE_HOST_DEVICE
static
constexpr
fp8_t
zero
()
{
return
bit_cast
<
fp8_t
>
(
static_cast
<
fp8_raw_t
>
(
0
));
}
};
template
<
>
struct
numeric
_traits
<
bf8_t
>
struct
numeric
<
bf8_t
>
{
static
constexpr
int
exp
=
5
;
static
constexpr
int
mant
=
2
;
#if defined(__gfx94__)
static
constexpr
int
bias
=
16
;
#else
static
constexpr
int
bias
=
15
;
// IEEE
#endif
};
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
min
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x04
));
// 0b00000100 = 2^-14
}
template
<
class
T
>
struct
numeric
;
// minumum finite value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
lowest
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0xfb
));
// 0b11111011 = -57344
}
// maximum finite value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
max
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7b
));
// 0b01111011 = 57344
}
// difference between 1.0 and next representable bf8 value (1.25)
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
epsilon
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x34
));
// 0.25
}
// rounding error (0.125)
// half of epsilon
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
round_error
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x30
));
// 0.125
}
// positive infinity value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
infinity
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7c
));
// 0b01111100
}
// quiet NaN
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
quiet_NaN
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x7F
));
// 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
signaling_NaN
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0xFF
));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
denorm_min
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0x01
));
}
CK_TILE_HOST_DEVICE
static
constexpr
bf8_t
zero
()
{
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0
));
}
};
#else
template
<
>
struct
numeric
<
fp8_t
>
{
...
...
@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
return
bit_cast
<
bf8_t
>
(
static_cast
<
bf8_raw_t
>
(
0
));
}
};
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT
(
CK_TILE_HOST_DEVICE
,
fp8_t
)
...
...
@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif
// math
CK_TILE_HOST_DEVICE fp8_t abs(const fp8_t& x)
template <typename T>
CK_TILE_HOST_DEVICE T abs(const T& x)
{
    return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
    static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
                  "Only fp8_t and bf8_t are supported");
    return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
}

CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x)
{
    uint8_t xx = bit_cast<fp8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
}
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) == 0x7f;
#else
    return xx == 0x80;
#endif
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE fp8_t sqrt(fp8_t x)
{
    return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
...
...
@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE fp8_t log(fp8_t x)
{
    return static_cast<fp8_t>(__logf(static_cast<float>(x)));
};

CK_TILE_HOST_DEVICE bf8_t abs(const bf8_t& x)
{
    return bit_cast<bf8_t>(static_cast<fp8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
#endif
CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x)
{
    uint8_t xx = bit_cast<bf8_raw_t>(x);
    return xx == 0x80; // TODO: NANOO
#if CK_TILE_USE_OCP_FP8
    return (xx & 0x7f) > 0x7c;
#else
    return xx == 0x80;
#endif
}
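The OCP branches reflect the two FP8 NaN conventions: e4m3 (fp8_t) has no infinity and treats the all-ones exponent-plus-mantissa pattern as NaN, while e5m2 (bf8_t) keeps IEEE-style semantics, with 0x7c as +infinity and any larger magnitude pattern as NaN; the FNUZ variants instead reserve the single pattern 0x80 (the "negative zero" slot) as their only NaN. A minimal sketch of the same checks written against raw bytes (helper names are hypothetical, not part of this header):

// Illustrative only: mirrors the branches above on raw bytes.
inline bool is_e4m3_nan(uint8_t bits) { return (bits & 0x7f) == 0x7f; }
inline bool is_e5m2_inf(uint8_t bits) { return (bits & 0x7f) == 0x7c; }
inline bool is_e5m2_nan(uint8_t bits) { return (bits & 0x7f) > 0x7c; }
inline bool is_fnuz_nan(uint8_t bits) { return bits == 0x80; }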
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE bf8_t sqrt(bf8_t x)
{
    return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
...
...
@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE bf8_t log(bf8_t x)
{
    return static_cast<bf8_t>(__logf(static_cast<float>(x)));
};
#endif
} // namespace ck_tile
include/ck_tile/core/numeric/half.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
...
...
@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;
    static constexpr uint16_t abs_mask  = 0x7FFF;
    static constexpr uint16_t Inf       = 0x7C00;
    static constexpr uint16_t NegInf    = 0xFC00;
    static constexpr uint16_t NaN       = 0x7C01;
    static constexpr uint16_t Neg0      = 0x8000;
    using bitwise_type = uint16_t;
};
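abs_mask = 0x7FFF selects everything but the sign bit of a half_t. A minimal sketch of how such a mask is typically used, analogous to the mask-based abs() for the 8-bit float types above (the helper name is hypothetical, not part of this header):

// Illustrative sketch only: clear the sign bit of a half_t via abs_mask.
CK_TILE_HOST_DEVICE inline half_t abs_via_mask(half_t x)
{
    return bit_cast<half_t>(
        static_cast<uint16_t>(bit_cast<uint16_t>(x) & numeric_traits<half_t>::abs_mask));
}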
...
...
include/ck_tile/core/numeric/numeric.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -89,6 +89,7 @@ struct numeric_traits<float>
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t abs_mask  = 0x7FFFFFFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
...
...
include/ck_tile/core/numeric/pk_int4.hpp
0 → 100644
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"
#pragma once
namespace ck_tile {
// Packed 2xint4
struct pk_int4_t
{
    using type = int8_t;
    type data;
    __host__ __device__ constexpr pk_int4_t() : data{type{}} {}
    __host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_int4_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
    {
        constexpr uint8_t val = 0b10001000;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
    {
        constexpr uint8_t val = 0b01110111;
        return pk_int4_t(bit_cast<int8_t>(val));
    }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
    {
        return 1; // not used
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
    {
        return 1; // not used
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
    {
        return 1; // not used
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
    {
        return 1; // not used
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero()
    {
        return 0;
    }
};
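Note that min()/lowest() pack the nibble 0b1000 into both halves of the byte and max() packs 0b0111, i.e. the two extreme 4-bit lanes. A quick illustrative check of that packing (the helper is hypothetical, not part of this header):

// Illustrative only: both nibbles of numeric<pk_int4_t>::max() should be 0b0111.
inline bool pk_int4_max_packs_sevens()
{
    const uint8_t v = bit_cast<uint8_t>(numeric<pk_int4_t>::max().data);
    return ((v & 0x0f) == 0x07) && ((v >> 4) == 0x07);
}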
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    fp32x2_t res = {x_h, x_l};
#else
    fp32x2_t res = {x_l, x_h};
#endif
    return res;
}
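Each nibble is decoded as an unsigned value 0..15 minus 8, so the result lies in [-8, 7]; with CK_TILE_USE_PK4_LAYOUT_SHUFFLE defined the high nibble lands in lane 0. A worked example:

// Illustrative worked example for the conversion above:
//   packed byte 0x2B: low nibble  0xB = 11 -> 11 - 8 =  3.0f
//                     high nibble 0x2 =  2 ->  2 - 8 = -6.0f
//   result = {3.0f, -6.0f}, or {-6.0f, 3.0f} when CK_TILE_USE_PK4_LAYOUT_SHUFFLE is defined.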
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400;
    const int SUB = 0xE408E408; //-8
    int lo        = i4s | EX;
    return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
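The half-precision path avoids an int-to-float conversion: OR-ing each nibble into the mantissa of a half whose exponent field is 0x6400 produces the value 1024 + n, and the packed add with 0xE408E408 (each half encodes -1032.0) then yields n - 8 for both lanes at once, the same "subtract 8" as the float path. The magic constants decode as follows:

// Illustrative decoding of the constants above:
//   0x6400 = 0 11001 0000000000 ->  1.0 * 2^(25-15)        = 1024.0, so (0x6400 | n) encodes 1024 + n
//   0xE408 = 1 11001 0000001000 -> -(1 + 8/1024) * 2^10    = -1032.0
//   (1024 + n) + (-1032) = n - 8 for each 16-bit lane of the packed add.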
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
    uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#else
    bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
    return res;
}
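A minimal usage sketch tying the three converters together; the function name is hypothetical and only illustrative, not part of this header:

CK_TILE_HOST_DEVICE inline void pk_int4_unpack_demo()
{
    // pack nibbles 0x2 (high) and 0xB (low) into one element
    pk_int4_t v{bit_cast<int8_t>(static_cast<uint8_t>(0x2B))};
    fp32x2_t f32 = pk_int4_t_to_fp32x2_t(v);     // {3.0f, -6.0f} without the shuffle layout
    fp16x2_t f16 = pk_int4_t_to_halfx2_t(v);     // same values as packed halves
    bf16x2_t bf  = pk_int4_t_to_bfloat16x2_t(v); // same values as packed bf16
    (void)f32; (void)f16; (void)bf;
}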
} // namespace ck_tile