gaoqiong / composable_kernel_ROCM · Commits · 9ba504b6

Commit 9ba504b6, authored Feb 07, 2025 by ThomasNing

    merge with the develop support the fp8 with computev4

Parents: e3402c93, f49de496
Changes: 198. Showing 20 changed files with 5218 additions and 70 deletions (+5218 -70).
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp                        +1    -0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp  +12   -5
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp     +1    -0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp                +1    -0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                                       +312  -11
include/ck/utility/amd_buffer_addressing.hpp                                                               +1    -1
include/ck/utility/amd_ck_fp8.hpp                                                                          +14   -28
include/ck/utility/amd_xdlops.hpp                                                                          +264  -1
include/ck/utility/blkgemmpipe_scheduler.hpp                                                               +10   -2
include/ck/utility/data_type.hpp                                                                           +619  -6
include/ck/utility/e8m0.hpp                                                                                +80   -0
include/ck/utility/mxf4_utils.hpp                                                                          +109  -0
include/ck/utility/mxf6_utils.hpp                                                                          +325  -0
include/ck/utility/mxf8_utils.hpp                                                                          +570  -0
include/ck/utility/mxfp_utils.hpp                                                                          +384  -0
include/ck/utility/scaled_type_convert.hpp                                                                 +877  -0
include/ck/utility/type_convert.hpp                                                                        +1340 -6
include/ck_tile/core.hpp                                                                                   +1    -0
include/ck_tile/core/arch/generic_memory_space_atomic.hpp                                                  +293  -10
include/ck_tile/core/config.hpp                                                                            +4    -0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp

@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp

@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
         b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

-    constexpr index_t Gemm1KPack = math::max(
-        math::lcm(MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
-                  B1K1),
-        MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
-    // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
-    // selected_mfma.k_per_blk <= Gemm1KPack
-    //
+    // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+    // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+    // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+    // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+    // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+    // therefore we may just as well assign Gemm1KPack = group_size
+    constexpr index_t Gemm1KPack =
+        MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;

     auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<BlockSize, ...
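For context (editor's sketch, not part of the diff): the removed expression let Gemm1KPack grow to lcm(group_size, B1K1), which can exceed the contiguous run each thread holds in VGPRs; the new code pins it to group_size. A minimal host-side illustration with assumed values (group_size = 4, B1K1 = 8, k_per_blk = 4 are illustrative, not taken from the commit):

#include <algorithm>
#include <numeric> // std::lcm (C++17)

int main()
{
    constexpr int group_size = 4; // contiguous A1 elements per thread in VGPRs (assumed)
    constexpr int B1K1       = 8; // assumed
    constexpr int k_per_blk  = 4; // assumed

    // old selection: lcm can exceed group_size (here lcm(4, 8) = 8)
    constexpr int old_pack = std::max(std::lcm(group_size, B1K1), k_per_blk);

    // new selection: never exceeds the contiguous run per thread, so the
    // summation indices of a1 and b1 stay aligned
    constexpr int new_pack = group_size;

    static_assert(old_pack == 8 && new_pack == 4, "old choice overshoots A1K1");
}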
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
     // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
     // therefore we may just as well assign Gemm1KPack = group_size
     constexpr index_t Gemm1KPack =
         MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -37,7 +37,17 @@ enum struct MfmaInstr
     mfma_f32_32x32x16f8bf8,
     mfma_f32_16x16x32f8bf8,
     mfma_f32_32x32x16bf8f8,
-    mfma_f32_16x16x32bf8f8
+    mfma_f32_16x16x32bf8f8,
+    mfma_f32_32x32x16f16,
+    mfma_f32_16x16x32f16,
+    mfma_f32_32x32x16bf16,
+    mfma_f32_16x16x32bf16,
+    mfma_i32_32x32x32i8,
+    mfma_i32_16x16x64i8,
+    mfma_f32_32x32x64f8f6f4,
+    mfma_f32_16x16x128f8f6f4,
+    mfma_scale_f32_32x32x64f8f6f4,
+    mfma_scale_f32_16x16x128f8f6f4
 };

 template <MfmaInstr instr>
@@ -198,6 +208,50 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32f16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
 {
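A quick consistency check of the new f16 trait values (editor's sketch; the identities follow the "???" comments that appear further down in this file for the f8f6f4 traits, and the values are copied from the hunk above):

// values copied from mfma_type<MfmaInstr::mfma_f32_32x32x16f16> above
static_assert(4 /*group_size*/ * 4 /*num_groups_per_blk*/ == 16 /*num_regs_per_blk*/, "");
static_assert(32 /*m_per_blk*/ * 32 /*n_per_blk*/ / 64 /*wave_size*/ == 16 /*num_regs_per_blk*/, "");
static_assert(32 /*num_threads_per_blk*/ == 32 /*n_per_blk*/, "");
// with is_k_reduction, the instruction K (16) is split across the input blocks:
static_assert(8 /*k_per_blk*/ * 2 /*num_input_blks*/ == 16, "");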
@@ -264,6 +318,28 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
 {
@@ -286,6 +362,28 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf16>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
 {
@@ -440,6 +538,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x64i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 16;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x64i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {
@@ -638,16 +780,115 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
     }
 };

+// TODO: fix mfma...f8f6f4 instructions
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 16;   // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 32;   // n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 2;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 32;   // from the instruction
+    static constexpr index_t n_per_blk           = 32;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 64 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
+{
+    // clang-format off
+    static constexpr index_t group_size          = 4;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_groups_per_blk  = 1;    // ??? group_size * num_groups_per_blk == num_regs_per_blk
+    static constexpr index_t num_regs_per_blk    = 4;    // m_per_blk * n_per_blk / wave_size
+    static constexpr index_t num_threads_per_blk = 16;   // == n_per_blk
+    static constexpr index_t wave_size           = 64;   // fixed
+    static constexpr index_t num_input_blks      = 4;    // m_per_blk / num_regs_per_blk
+    static constexpr index_t num_output_blks     = 1;    // (is_k_reduction == true) ???
+    static constexpr index_t m_per_blk           = 16;   // from the instruction
+    static constexpr index_t n_per_blk           = 16;   // from the instruction
+    static constexpr index_t k_per_blk           = 32;   // (is_k_reduction == true) ? 128 / num_input_blks
+    static constexpr bool is_k_reduction         = true; // ???
+    // clang-format on
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
-template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, typename additional_type = base_type>
+template <typename base_type,
+          index_t MPerXdlops,
+          index_t NPerXdlops,
+          typename additional_type = base_type,
+          bool is_single_rate_mfma = false>
 struct MfmaSelector
 {
-    template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_, typename additional_type_ = base_type_>
+    template <typename base_type_,
+              index_t MPerXdlops_,
+              index_t NPerXdlops_,
+              typename additional_type_ = base_type_,
+              bool is_single_rate_mfma_ = false>
     static constexpr auto GetMfma();

     template <>
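The "???" comments invite verification; an editor's sketch checking the stated identities for the two f8f6f4 shapes (values copied from the traits above):

// mfma_f32_32x32x64f8f6f4: k_per_blk == 64 / num_input_blks, regs from tile/wave
static_assert(32 /*k_per_blk*/ == 64 / 2 /*num_input_blks*/, "");
static_assert(4 * 4 == 16 && 32 * 32 / 64 == 16, "group/regs identities, 32x32 shape");
// mfma_f32_16x16x128f8f6f4: k_per_blk == 128 / num_input_blks
static_assert(32 /*k_per_blk*/ == 128 / 4 /*num_input_blks*/, "");
static_assert(4 * 1 == 4 && 16 * 16 / 64 == 4, "group/regs identities, 16x16 shape");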
@@ -711,13 +952,32 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<half_t, 32, 32>()
+    constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16f16;
+#else
         return MfmaInstr::mfma_f32_32x32x8f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_32x32x8f16;
     }

     template <>
-    constexpr auto GetMfma<half_t, 16, 16>()
+    constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32f16;
+#else
         return MfmaInstr::mfma_f32_16x16x16f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_16x16x16f16;
     }
@@ -741,7 +1001,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 32, 32>()
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_32x32x16bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
+    {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_32x32x8bf16_1k;
 ...
@@ -751,7 +1023,19 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 16, 16>()
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
     {
+#if defined(__gfx950__)
+        return MfmaInstr::mfma_f32_16x16x32bf16;
+#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
+    {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
         return MfmaInstr::mfma_f32_16x16x16bf16_1k;
 ...
@@ -760,7 +1044,18 @@ struct MfmaSelector
 #endif
     }

-#if defined(CK_USE_AMD_MFMA_GFX940)
+#if defined(__gfx950__)
+    template <>
+    constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x32i8;
+    }
+
+    template <>
+    constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x64i8;
+    }
+#elif defined(__gfx942__)
     template <>
     constexpr auto GetMfma<int8_t, 32, 32>()
     {
 ...
@@ -832,8 +1127,8 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_16x16x32bf8f8;
     }

-    static constexpr auto selected_mfma =
-        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
+    static constexpr auto selected_mfma =
+        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};

     __host__ __device__ constexpr MfmaSelector()
     {
@@ -1135,7 +1430,13 @@ struct XdlopsGemm
         return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
     }

-    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
+    // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
+    static constexpr auto mfma =
+        MfmaSelector<base_type,
+                     MPerXdlops,
+                     NPerXdlops,
+                     additional_type,
+                     ((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
+                      KPack <= 4)
+                         ? true
+                         : false>{};

     static constexpr auto mfma_instr = mfma.selected_mfma;
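An illustrative instantiation of the new is_single_rate_mfma knob (editor's sketch; KPack values are assumed, not from the commit): with KPack <= 4 on an fp16 problem the selector is forced onto the single-rate instruction, so a kernel configured with a small KPack still gets a matching MFMA on gfx950.

// KPack <= 4: force single-rate -> mfma_f32_32x32x8f16 even on gfx950
using SmallK = MfmaSelector<half_t, 32, 32, half_t, /*is_single_rate_mfma=*/true>;
// KPack > 4: default path -> mfma_f32_32x32x16f16 on gfx950, mfma_f32_32x32x8f16 elsewhere
using LargeK = MfmaSelector<half_t, 32, 32, half_t, /*is_single_rate_mfma=*/false>;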
include/ck/utility/amd_buffer_addressing.hpp

@@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
                 tmp.template AsType<half2_t>()[i]);
         });
     }
-#if defined(__gfx942__)
+#if defined(__gfx942__) || defined(__gfx950__)
     else if constexpr(is_same<T, bhalf_t>::value)
     {
         vector_type<bhalf_t, N> tmp{src_thread_data};
include/ck/utility/amd_ck_fp8.hpp
@@ -20,39 +20,25 @@
 #define CK_USE_OCP_FP8 0
 #endif

-namespace {
-// https://en.cppreference.com/w/cpp/types/conditional
-template <bool B, class T, class F>
-struct conditional
-{
-    using type = T;
-};
-
-template <class T, class F>
-struct conditional<false, T, F>
-{
-    using type = F;
-};
-} // namespace
-
-namespace ck {
-
-using f8_fnuz_t  = _BitInt(8);
-using bf8_fnuz_t = unsigned _BitInt(8);
-
 #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
-    defined(__gfx1201__)) &&                                                                      \
+    defined(__gfx1201__) || defined(__gfx950__)) &&                                               \
    __HIP_DEVICE_COMPILE__
 #define CK_FP8_CVT_FAST_PATH 1
 #else
 #define CK_FP8_CVT_FAST_PATH 0
 #endif

-#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
 #define CK_OCP_FP8_CVT_FAST_PATH 1
 #else
 #define CK_OCP_FP8_CVT_FAST_PATH 0
 #endif

+namespace ck {
+
+using f8_fnuz_t  = _BitInt(8);
+using bf8_fnuz_t = unsigned _BitInt(8);
+
 typedef unsigned char fp8_storage_t;

 /**
@@ -207,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
         }
     }

-    typename conditional<
-        sizeof(T) == 2,
-        unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
+    typename std::conditional<
+        sizeof(T) == 2,
+        unsigned short int,
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
+        retval;

     if constexpr(we == 5 && is_half && !is_fnuz)
     {
@@ -303,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
         return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
     }
 }
-
 #endif
 } // namespace fp8_impl
@@ -378,7 +364,7 @@ struct bf8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(
 ...
@@ -392,7 +378,7 @@ struct bf8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
         return static_cast<_Float16>(
             fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
 ...
@@ -553,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
     constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
-    using T_bitwise = typename conditional<
-        sizeof(T) == 2,
-        unsigned short int,
-        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
+    using T_bitwise = typename std::conditional<
+        sizeof(T) == 2,
+        unsigned short int,
+        typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
     unsigned long long x{x_bitwise};
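The deleted anonymous-namespace `conditional` was a local re-implementation of `std::conditional`; the replacement resolves the same size-keyed bitwise storage type at compile time. A standalone sketch of the pattern (editor's example, not part of the diff):

#include <cstdint>
#include <type_traits>

// pick an unsigned integer wide enough to hold T's bit pattern
template <typename T>
using bitwise_of = typename std::conditional<
    sizeof(T) == 2,
    unsigned short int,
    typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

static_assert(std::is_same<bitwise_of<short>, unsigned short int>::value, "2 bytes -> u16");
static_assert(std::is_same<bitwise_of<float>, unsigned int>::value, "4 bytes -> u32");
static_assert(std::is_same<bitwise_of<double>, unsigned long long>::value, "8 bytes -> u64");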
include/ck/utility/amd_xdlops.hpp
@@ -5,7 +5,7 @@
 namespace ck {

 // Define the common macro for MI300 models
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
 #define __gfx94__
 #endif
@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16f16;
+
+template <>
+struct intrin_mfma_f32_32x32x16f16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32f16;
+
+template <>
+struct intrin_mfma_f32_16x16x32f16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8f16;
@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
 };

 // bfp16
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x16bf16;
+
+template <>
+struct intrin_mfma_f32_32x32x16bf16<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
+            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x32bf16;
+
+template <>
+struct intrin_mfma_f32_16x16x32bf16<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x8bf16_1k;
@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_32x32x32i8;
+
+template <>
+struct intrin_mfma_i32_32x32x32i8<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_i32_16x16x64i8;
+
+template <>
+struct intrin_mfma_i32_16x16x64i8<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
+            reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif // defined(__gfx950__)
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_i32_32x32x16i8;
@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_32x32x64f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
+/// and f4 data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4;
+
+template <>
+struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
+#else
+        ignore = reg_a;
+        ignore = scale_a;
+        ignore = reg_b;
+        ignore = scale_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
+template <index_t MPerWave, index_t NPerWave>
+struct intrin_mfma_f32_16x16x128f8f6f4;
+
+/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
+/// data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported in
+/// the backend. That is the intended use. There is a backend optimization to select the unscaled
+/// operation if the scale is 0.
+template <>
+struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
+    {
+#if defined(__gfx950__)
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
+    }
+};
+
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8f8;
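Per the doc comments above, the unscaled wrappers emit the scaled builtin with every scale operand hard-coded to 0 and rely on the backend to fold that back to the unscaled machine instruction. Schematically (editor's sketch of the call shapes; a, b, acc, scale_a, scale_b are hypothetical caller-side names):

// unscaled wrapper: scale operands fixed at 0 inside Run; the compiler backend
// recognizes zero scales and selects the unscaled MFMA encoding
intrin_mfma_f32_32x32x64f8f6f4<32, 32>::Run(a, b, acc);
// scaled wrapper: caller supplies packed e8m0 scale words for A and B
intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>::Run(a, scale_a, b, scale_b, acc);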
include/ck/utility/blkgemmpipe_scheduler.hpp

@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
                KPerXDL);

         printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
-               "%d, %d\n C MFMA inst: %d\n",
+               "%d, %d\n C MFMA inst: %d\n"
+               "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
+               "%d/ %d\n",
                A_Buffer_Load_Inst_Num,
                B_Buffer_Load_Inst_Num,
                A_LDS_Write_Inst_Num,
                B_LDS_Write_Inst_Num,
                A_LDS_Read_Inst_Num,
                B_LDS_Read_Inst_Num,
-               C_MFMA_Inst_Num);
+               C_MFMA_Inst_Num,
+               A_LDS_Read_Width,
+               B_LDS_Read_Width,
+               ALDSWriteWidth,
+               BLDSWriteWidth,
+               ABufferLoadWidth,
+               BBufferLoadWidth);
     }
 };
include/ck/utility/data_type.hpp
@@ -4,6 +4,7 @@
 #pragma once

 #include "ck/utility/amd_ck_fp8.hpp"
 #include "ck/utility/e8m0.hpp"
+#include "ck/utility/statically_indexed_array.hpp"

 #ifdef CK_CODE_GEN_RTC
 using int8_t = signed char;
@@ -23,6 +24,296 @@ using std::byte;
 using bhalf_t = ushort;
 using half_t  = _Float16;
 using int4_t  = _BitInt(4);
+using f4_t    = unsigned _BitInt(4);
+using f6_t    = _BitInt(6);          // e2m3 format
+using bf6_t   = unsigned _BitInt(6); // e3m2 format
+
+struct f4x2_pk_t
+{
+    using type = uint8_t;
+    type data;
+    f4x2_pk_t() : data{type{}} {}
+    f4x2_pk_t(type init) : data{init} {}
+
+    template <index_t I>
+    __host__ __device__ inline type unpack(Number<I>) const
+    {
+        static_assert(I < 2, "Index is out of range.");
+        if constexpr(I == 0)
+            return data & 0b00001111;
+        else
+            return (data >> 4);
+    }
+
+    __host__ __device__ inline type pack(const type x0, const type x1)
+    {
+        return (x1 << 4) | (x0 & 0b00001111);
+    }
+};
+
+struct f6x16_pk_t
+{
+    // store 16 elements of f6_t in an array of 3 uint32_t
+    using element_type = uint32_t;
+    using type         = StaticallyIndexedArray_v2<element_type, 3>;
+    type data;
+    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
+    f6x16_pk_t() : data{type{}} {}
+    f6x16_pk_t(type init) : data{init} {}
+
+    template <index_t I>
+    __host__ __device__ inline f6_t unpack(Number<I>)
+    {
+        static_assert(I < 16, "Index out of range for 16 f6_t elements.");
+        constexpr int num_bits_elem     = 6;
+        constexpr int num_bits_vec_elem = 32;
+        constexpr int vector_size       = 3;
+        constexpr int bit_pos           = I * num_bits_elem;
+        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
+        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+
+        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
+
+        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
+        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
+        {
+            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
+                    << (num_bits_elem - overhang);
+        }
+
+        return static_cast<f6_t>(bits & 0x3F);
+    }
+
+    __host__ __device__ inline type pack(const test_vec_t& x)
+    {
+        type packed{};
+        // for each of the 16 f6_t values, place its 6 bits in the correct position
+        ck::static_for<0, 16, 1>{}([&](auto i) {
+            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
+
+            constexpr int num_bits_elem     = 6;
+            constexpr int num_bits_vec_elem = 32;
+            constexpr int vector_size       = 3;
+            constexpr int bit_pos           = i * num_bits_elem;
+            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
+            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
+
+            uint32_t old_value = packed.At(Number<arr_index>{});
+
+            // insert bits into the current 32-bit block
+            old_value |= (bits << bit_offset);
+            packed.At(Number<arr_index>{}) = old_value;
+
+            // if it crosses into the next block, shift the remainder
+            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
+            {
+                uint32_t next_value = packed.At(Number<arr_index + 1>{});
+                next_value |= (bits >> (num_bits_elem - overhang));
+                packed.At(Number<arr_index + 1>{}) = next_value;
+            }
+        });
+        return packed;
+    }
+};
+
+struct f6x32_pk_t
+{
+    // store 32 elements of f6_t in an array of 6 uint32_t
+    using element_type = uint32_t;
+    using type         = StaticallyIndexedArray_v2<element_type, 6>;
+    type data;
+    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
+    f6x32_pk_t() : data{type{}} {}
+    f6x32_pk_t(type init) : data{init} {}
+
+    template <index_t I>
+    __host__ __device__ inline f6_t unpack(Number<I>)
+    {
+        static_assert(I < 32, "Index out of range for 32 f6_t elements.");
+        constexpr int num_bits_elem     = 6;
+        constexpr int num_bits_vec_elem = 32;
+        constexpr int vector_size       = 6;
+        constexpr int bit_pos           = I * num_bits_elem;
+        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
+        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+
+        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
+
+        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
+        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
+        {
+            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
+                    << (num_bits_elem - overhang);
+        }
+
+        return static_cast<f6_t>(bits & 0x3F);
+    }
+
+    __host__ __device__ inline type pack(const test_vec_t& x)
+    {
+        type packed{};
+        // for each of the 32 f6_t values, place its 6 bits in the correct position
+        ck::static_for<0, 32, 1>{}([&](auto i) {
+            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
+
+            constexpr int num_bits_elem     = 6;
+            constexpr int num_bits_vec_elem = 32;
+            constexpr int vector_size       = 6;
+            constexpr int bit_pos           = i * num_bits_elem;
+            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
+            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
+
+            uint32_t old_value = packed.At(Number<arr_index>{});
+
+            // insert bits into the current 32-bit block
+            old_value |= (bits << bit_offset);
+            packed.At(Number<arr_index>{}) = old_value;
+
+            // if it crosses into the next block, shift the remainder
+            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
+            {
+                uint32_t next_value = packed.At(Number<arr_index + 1>{});
+                next_value |= (bits >> (num_bits_elem - overhang));
+                packed.At(Number<arr_index + 1>{}) = next_value;
+            }
+        });
+        return packed;
+    }
+};
+
+struct bf6x16_pk_t
+{
+    // store 16 elements of bf6_t in an array of 3 uint32_t
+    using element_type = uint32_t;
+    using type         = StaticallyIndexedArray_v2<element_type, 3>;
+    type data;
+    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
+    bf6x16_pk_t() : data{type{}} {}
+    bf6x16_pk_t(type init) : data{init} {}
+
+    template <index_t I>
+    __host__ __device__ inline bf6_t unpack(Number<I>)
+    {
+        static_assert(I < 16, "Index out of range for 16 f6_t elements.");
+        constexpr int num_bits_elem     = 6;
+        constexpr int num_bits_vec_elem = 32;
+        constexpr int vector_size       = 3;
+        constexpr int bit_pos           = I * num_bits_elem;
+        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
+        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+
+        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
+
+        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
+        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
+        {
+            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
+                    << (num_bits_elem - overhang);
+        }
+
+        return static_cast<bf6_t>(bits & 0x3F);
+    }
+
+    __host__ __device__ inline type pack(const test_vec_t& x)
+    {
+        type packed{};
+        // for each of the 16 bf6_t values, place its 6 bits in the correct position
+        ck::static_for<0, 16, 1>{}([&](auto i) {
+            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
+
+            constexpr int num_bits_elem     = 6;
+            constexpr int num_bits_vec_elem = 32;
+            constexpr int vector_size       = 3;
+            constexpr int bit_pos           = i * num_bits_elem;
+            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
+            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
+
+            uint32_t old_value = packed.At(Number<arr_index>{});
+
+            // insert bits into the current 32-bit block
+            old_value |= (bits << bit_offset);
+            packed.At(Number<arr_index>{}) = old_value;
+
+            // if it crosses into the next block, shift the remainder
+            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
+            {
+                uint32_t next_value = packed.At(Number<arr_index + 1>{});
+                next_value |= (bits >> (num_bits_elem - overhang));
+                packed.At(Number<arr_index + 1>{}) = next_value;
+            }
+        });
+        return packed;
+    }
+};
+
+struct bf6x32_pk_t
+{
+    // store 32 elements of bf6_t in an array of 6 uint32_t
+    using element_type = uint32_t;
+    using type         = StaticallyIndexedArray_v2<element_type, 6>;
+    type data;
+    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
+    bf6x32_pk_t() : data{type{}} {}
+    bf6x32_pk_t(type init) : data{init} {}
+
+    template <index_t I>
+    __host__ __device__ inline bf6_t unpack(Number<I>)
+    {
+        static_assert(I < 32, "Index out of range for 32 f6_t elements.");
+        constexpr int num_bits_elem     = 6;
+        constexpr int num_bits_vec_elem = 32;
+        constexpr int vector_size       = 6;
+        constexpr int bit_pos           = I * num_bits_elem;
+        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
+        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+
+        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
+
+        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
+        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
+        {
+            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
+                    << (num_bits_elem - overhang);
+        }
+
+        return static_cast<bf6_t>(bits & 0x3F);
+    }
+
+    __host__ __device__ inline type pack(const test_vec_t& x)
+    {
+        type packed{};
+        // for each of the 32 bf6_t values, place its 6 bits in the correct position
+        ck::static_for<0, 32, 1>{}([&](auto i) {
+            uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
+
+            constexpr int num_bits_elem     = 6;
+            constexpr int num_bits_vec_elem = 32;
+            constexpr int vector_size       = 6;
+            constexpr int bit_pos           = i * num_bits_elem;
+            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
+            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
+            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;
+
+            uint32_t old_value = packed.At(Number<arr_index>{});
+
+            // insert bits into the current 32-bit block
+            old_value |= (bits << bit_offset);
+            packed.At(Number<arr_index>{}) = old_value;
+
+            // if it crosses into the next block, shift the remainder
+            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
+            {
+                uint32_t next_value = packed.At(Number<arr_index + 1>{});
+                next_value |= (bits >> (num_bits_elem - overhang));
+                packed.At(Number<arr_index + 1>{}) = next_value;
+            }
+        });
+        return packed;
+    }
+};
+
 // custom data type - pack int4
 struct pk_i4_t
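To make the bit arithmetic in pack/unpack concrete, a worked example (editor's note, values derived from the code above): element I = 5 of an f6x16_pk_t starts at bit 30, so it straddles two 32-bit words.

// I = 5: bit_pos = 5 * 6 = 30, arr_idx = 30 / 32 = 0, bit_offset = 30 % 32 = 30
// overhang = 30 + 6 - 32 = 4  -> the low 2 bits live in word 0, the high 4 bits in word 1
//
// unpack therefore evaluates to:
//   bits  = data.At(Number<0>{}) >> 30;            // low 2 bits of the element
//   bits |= (data.At(Number<1>{}) & 0b1111) << 2;  // high 4 bits from the next word
//   return static_cast<f6_t>(bits & 0x3F);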
@@ -40,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x)
 }

 // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
-// native types: bool
+// native types: bool, f4_t, f6_t, bf6_t
 template <typename T>
 inline constexpr bool is_native_type()
 {
     return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
            is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
-           is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
-           is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
+           is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
+           is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
+           is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
 }

 // vector_type
@@ -1370,12 +1662,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
 {
     using type = f8_ocp_t::data_type;
 };
 template <>
 struct nnvb_data_t_selector<bf8_ocp_t>
 {
     using type = bf8_ocp_t::data_type;
 };
+template <>
+struct nnvb_data_t_selector<f6x16_pk_t>
+{
+    using type = f6x16_pk_t::type;
+};
+template <>
+struct nnvb_data_t_selector<f6x32_pk_t>
+{
+    using type = f6x32_pk_t::type;
+};
+template <>
+struct nnvb_data_t_selector<bf6x16_pk_t>
+{
+    using type = bf6x16_pk_t::type;
+};
+template <>
+struct nnvb_data_t_selector<bf6x32_pk_t>
+{
+    using type = bf6x32_pk_t::type;
+};
 template <>
 struct nnvb_data_t_selector<pk_i4_t>
 {
@@ -1482,6 +1799,63 @@ struct non_native_vector_base<
 };

+// implementation for f6x16 and f6x32
+template <typename T, index_t N>
+struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
+{
+    using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
+    using element_t = typename T::element_type;            // select element_t based on declared element type
+    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
+    static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
+
+    using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
+    using type   = non_native_vector_base<T, N>;
+
+    union alignas(next_pow2(N * sizeof(T)))
+    {
+        data_v dN; // storage vector;
+        StaticallyIndexedArray<data_t, N> dxN;
+        StaticallyIndexedArray<T, N> dTxN;
+        StaticallyIndexedArray<data_v, 1> dNx1;
+    } data_;
+
+    __host__ __device__ constexpr non_native_vector_base(data_t a)
+        : data_{data_v(a.At(Number<0>{}))}
+    {
+    }
+    __host__ __device__ constexpr non_native_vector_base(T f)
+        : non_native_vector_base(bit_cast<data_t>(f))
+    {
+    }
+    __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
+    __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
+
+    __host__ __device__ constexpr operator data_v() const { return data_.dN; }
+    __host__ __device__ constexpr operator data_t() const
+    {
+        if constexpr(N == 1)
+        {
+            return data_.dxN[Number<0>{}];
+        }
+        else
+        {
+            return data_.dxN; // XXX this should cause an error
+        }
+    }
+    __host__ __device__ constexpr operator T() const
+    {
+        if constexpr(N == 1)
+        {
+            return data_.dTxN[Number<0>{}];
+        }
+        else
+        {
+            return data_.dTxN; // XXX this should cause an error
+        }
+    }
+};
+
 template <typename T, index_t N>
 struct scalar_type<non_native_vector_base<T, N>>;
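A sketch of what the enable_if arm accepts (editor's arithmetic, consistent with the inline comments above): f6x16_pk_t packs 16 x 6 bits = 96 bits into 3 x uint32_t = 12 bytes, and f6x32_pk_t packs 32 x 6 bits = 192 bits into 6 x uint32_t = 24 bytes, so sizeof(T) == 12 || sizeof(T) == 24 selects exactly these packed types and size_factor comes out as 3 or 6.

// editor's sanity checks, assuming the StaticallyIndexedArray_v2 storage has no padding
static_assert(sizeof(f6x16_pk_t) == 12, "3 x uint32_t payload");
static_assert(sizeof(f6x32_pk_t) == 24, "6 x uint32_t payload");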
@@ -2217,6 +2591,22 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
 using uint8x32_t = typename vector_type<uint8_t, 32>::type;
 using uint8x64_t = typename vector_type<uint8_t, 64>::type;

+// f4
+using f4x2_t  = typename vector_type<f4x2_pk_t, 1>::type;
+using f4x4_t  = typename vector_type<f4x2_pk_t, 2>::type;
+using f4x8_t  = typename vector_type<f4x2_pk_t, 4>::type;
+using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
+using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
+using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
+
+// f6
+using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
+using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
+
+// bf6
+using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
+using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
+
 // pack int4
 using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
 using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
@@ -2571,6 +2961,118 @@ struct NumericLimits<bf8_ocp_t>
 };
 #endif

+template <>
+struct NumericLimits<f4_t>
+{
+    static constexpr uint8_t binary_min_normal    = 0x2; // 0b0010
+    static constexpr uint8_t binary_max_normal    = 0x7; // 0b0111
+    static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
+    static constexpr uint8_t binary_min_subnorm   = 0x1; // 0b0001
+    static constexpr uint8_t binary_max_subnorm   = 0x1; // 0b0001
+
+    static constexpr float data_max_normal_number    = 6;
+    static constexpr float data_min_subnormal_number = 0.5;
+
+    __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
+    __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
+    __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
+    __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
+    __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
+    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
+    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
+};
+
+template <>
+struct NumericLimits<f6_t>
+{
+    static constexpr uint8_t binary_min_normal    = 0x08; // 0b001000
+    static constexpr uint8_t binary_max_normal    = 0x1F; // 0b011111
+    static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
+    static constexpr uint8_t binary_min_subnorm   = 0x01; // 0b000001
+    static constexpr uint8_t binary_max_subnorm   = 0x07; // 0b000111
+
+    static constexpr float data_max_normal_number    = 7.5;
+    static constexpr float data_min_subnormal_number = 0.125;
+
+    __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); }
+    __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); }
+    __host__ __device__ static constexpr f6_t Lowest() { return f6_t(binary_lowest_normal & 0b111111); }
+    __host__ __device__ static constexpr f6_t MinSubnorm() { return f6_t(binary_min_subnorm & 0b111111); }
+    __host__ __device__ static constexpr f6_t MaxSubnorm() { return f6_t(binary_max_subnorm & 0b111111); }
+    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
+    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
+};
+
+template <>
+struct NumericLimits<bf6_t>
+{
+    static constexpr uint8_t binary_min_normal    = 0x08; // 0b001000
+    static constexpr uint8_t binary_max_normal    = 0x1F; // 0b011111
+    static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
+    static constexpr uint8_t binary_min_subnorm   = 0x01; // 0b000001
+    static constexpr uint8_t binary_max_subnorm   = 0x03; // 0b000011
+
+    static constexpr float data_max_normal_number    = 28;
+    static constexpr float data_min_subnormal_number = 0.0625;
+
+    __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }
+    __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }
+    __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }
+    __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }
+    __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }
+    __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
+    __host__ __device__ static constexpr float DataMinSubnorm() { return data_min_subnormal_number; }
+};
+
+template <>
+struct NumericLimits<e8m0_bexp_t>
+{
+    static constexpr e8m0_bexp_t binary_min  = 0x00; // 0b00000000
+    static constexpr e8m0_bexp_t binary_max  = 0xFE; // 0b11111110
+    static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
+    static constexpr e8m0_bexp_t binary_1    = 0x7F; // 0b01111111
+    static constexpr e8m0_bexp_t binary_2    = 0x80; // 0b10000000
+    static constexpr e8m0_bexp_t binary_3    = 0x82; // 0b10000010
+    static constexpr e8m0_bexp_t binary_135  = 0x87; // 0b10000111
+    static constexpr e8m0_bexp_t binary_142  = 0x8E; // 0b10001110
+
+    __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
+    __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
+    __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
+    __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
+    __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
+    __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
+    __host__ __device__ static constexpr e8m0_bexp_t Binary_135() { return e8m0_bexp_t(binary_135); }
+    __host__ __device__ static constexpr e8m0_bexp_t Binary_142() { return e8m0_bexp_t(binary_142); }
+};
+
 template <typename T>
 struct NumericUtils
 {
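Decoding the e2m1 limits above by hand (editor's arithmetic, assuming the 1-sign/2-exponent/1-mantissa layout with bias = 1 given in NumericUtils<f4_t> below): the nibble 0b0111 is sign 0, exponent 0b11 = 3, mantissa 1, i.e. (1 + 0.5) * 2^(3-1) = 6, matching data_max_normal_number; 0b0001 is the lone subnormal, 0.5 * 2^(1-1) = 0.5. A host-side reference decoder:

#include <cstdint>

// reference decoder for the 4-bit e2m1 layout (editor's sketch; bias = 1)
float decode_f4(uint8_t nibble)
{
    const int sign = (nibble >> 3) & 1;
    const int exp  = (nibble >> 1) & 0x3;
    const int mant = nibble & 0x1;
    const float mag = (exp == 0) ? mant * 0.5f                           // subnormal
                                 : (1.0f + mant * 0.5f) * (1 << (exp - 1)); // normal, 2^(exp-bias)
    return sign ? -mag : mag;
}
// decode_f4(0b0111) == 6.0f (max normal); decode_f4(0b0001) == 0.5f (min subnormal)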
@@ -2590,6 +3092,7 @@ struct NumericUtils<float>
     static constexpr uint32_t NegInf = 0xFF800000;
     static constexpr uint32_t NaN    = 0x7F800001;
     static constexpr uint32_t Neg0   = 0x80000000;
+    static constexpr bool has_inf    = true;
     using bitwise_type = uint32_t;
 };
@@ -2607,9 +3110,19 @@ struct NumericUtils<half_t>
     static constexpr uint32_t NegInf = 0xFC00;
     static constexpr uint32_t NaN    = 0x7C01;
     static constexpr uint32_t Neg0   = 0x8000;
+    static constexpr bool has_inf    = true;
     using bitwise_type = uint16_t;
 };

+template <>
+struct NumericUtils<bhalf_t>
+{
+    static constexpr int exp  = 8;
+    static constexpr int mant = 7;
+    static constexpr int bias = 128; // negative zero nan mode
+    // static constexpr int bias = 127; // ieee mode
+};
+
 template <>
 struct NumericUtils<f8_fnuz_t>
 {
@@ -2617,6 +3130,7 @@ struct NumericUtils<f8_fnuz_t>
     static constexpr int mant = 3;
     static constexpr int bias = 8; // negative zero nan mode
     // static constexpr int bias = 7; // ieee mode
+    static constexpr bool has_inf = false;
 };

 template <>
@@ -2626,6 +3140,7 @@ struct NumericUtils<bf8_fnuz_t>
     static constexpr int mant = 2;
     static constexpr int bias = 16; // negative zero nan mode
     // static constexpr int bias = 15; // ieee mode
+    static constexpr bool has_inf = false;
 };

 template <>
 struct NumericUtils<f8_ocp_t>
@@ -2644,11 +3159,109 @@ struct NumericUtils<bf8_ocp_t>
};
template
<
>
struct
NumericUtils
<
bhalf_t
>
struct
NumericUtils
<
f4_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
1
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
10
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b0000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b1000
;
static
constexpr
uint8_t
one_mask
=
0b0010
;
static
constexpr
uint8_t
set_sign_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b1111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b0001
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b1001
;
static
constexpr
bool
has_inf
=
false
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
f6_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
12
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b000000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b100000
;
static
constexpr
uint8_t
set_sign_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b111111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b000111
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b100111
;
static
constexpr
bool
has_inf
=
false
;
static
constexpr
bool
has_nan
=
false
;
static
constexpr
bool
has_zero
=
true
;
using
bitwise_type
=
uint8_t
;
};
template <>
struct NumericUtils<bf6_t>
{
    static constexpr int exp  = 3;
    static constexpr int mant = 2;
    static constexpr int bias = 3;
    static constexpr uint32_t sr_shift = 11;

    static constexpr int unbiased_exp_min = -2;
    static constexpr int unbiased_exp_max = 4;
    static constexpr int biased_exp_min   = 1;
    static constexpr int biased_exp_max   = 7;

    static constexpr uint8_t positive_zero_mask = 0b000000;
    static constexpr uint8_t negative_zero_mask = 0b100000;

    static constexpr uint8_t set_sign_mask = 0b011111;

    static constexpr uint8_t data_max_positive_normal_mask    = 0b011111;
    static constexpr uint8_t data_max_negative_normal_mask    = 0b111111;
    static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011;
    static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011;

    static constexpr bool has_inf  = false;
    static constexpr bool has_nan  = false;
    static constexpr bool has_zero = true;

    using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<e8m0_bexp_t>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 0;
    static constexpr int bias = 127;

    static constexpr int unbiased_exp_min = -127;
    static constexpr int unbiased_exp_max = 127;
    static constexpr int biased_exp_min   = 0;
    static constexpr int biased_exp_max   = 254;

    using bitwise_type = uint8_t;
};
} // namespace ck
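Taken together, the exp/mant/bias fields and the bit masks above fully determine each mini-float layout. As a quick illustration (not part of the commit; the helper name is hypothetical), decoding a 4-bit e2m1 nibble by hand with the f4_t constants:

// Illustrative only: decode a 4-bit e2m1 value using the NumericUtils<f4_t> layout.
// sign = bit 3, exponent = bits 2..1, mantissa = bit 0, bias = 1.
#include <cmath>
#include <cstdint>

float decode_f4(uint8_t nibble) // hypothetical helper, not from the commit
{
    const int sign = (nibble >> 3) & 0x1;
    const int e    = (nibble >> 1) & 0x3; // NumericUtils<f4_t>::exp bits
    const int m    = nibble & 0x1;        // NumericUtils<f4_t>::mant bit
    // biased exponent 0 marks a subnormal, IEEE-style
    const float frac = (e == 0) ? (m * 0.5f) : (1.0f + m * 0.5f);
    const int   unb  = (e == 0) ? 0 : (e - 1); // 1 - bias = 0 for subnormals
    return (sign ? -1.0f : 1.0f) * std::ldexp(frac, unb);
}

With these constants, 0b0111 decodes to 1.5 * 2^2 = 6.0 (the data_max_positive_normal_mask value) and 0b0001 to the minimum subnormal 0.5.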
include/ck/utility/e8m0.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace ck
{
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
    using type = uint8_t;
    type data;

    constexpr static type bias     = 127;
    constexpr static type nan_mask = 0xFF;

    __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
    __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
    __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
    {
    }
    __host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
        : data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
    {
    }

    __host__ __device__ explicit constexpr operator float() const
    {
        if(data == nan_mask || data == 0)
        {
            uint32_t bits = data << 1;
            bits |= 1;
            bits <<= 22;
            return bit_cast<float>(bits);
        }
        else
        {
            uint32_t bits = data << 23;
            return bit_cast<float>(bits);
        }
    }

    __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
    {
        // strict IEEE compliance for NaN
        return data == other.data && data != nan_mask;
    }

    __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
namespace utils
{

template <typename T>
__host__ __device__ inline int get_exponent_value(T x);

template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
    return x.data;
}

} // namespace utils
} // namespace ck
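Since e8m0_bexp_t is nothing but a biased power-of-two exponent, a small host-side sanity check (illustrative, not part of the commit) makes the encodings from the comment above concrete:

// Illustrative usage sketch for ck::e8m0_bexp_t (assumes the ck headers are on the include path).
#include "ck/utility/e8m0.hpp"
#include <cassert>

int main()
{
    using ck::e8m0_bexp_t;
    assert(static_cast<float>(e8m0_bexp_t{0b01111111}) == 1.0f);   // 2^(127-127)
    assert(static_cast<float>(e8m0_bexp_t{0b10000111}) == 256.0f); // 2^(135-127)
    assert(e8m0_bexp_t{0xFF}.is_nan());                            // all-ones byte encodes NaN
    assert(!(e8m0_bexp_t{0xFF} == e8m0_bexp_t{0xFF}));             // NaN != NaN, IEEE-style
}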
include/ck/utility/mxf4_utils.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils
{

template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
                                             f4_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation;
    // comparing against QuietNaN() with operator== would always be false, so test the bit pattern
    return scale.is_nan();
}

// no infinity representation in ocp_e2m1_mxfp4; will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f4_t const data [[maybe_unused]])
{
    // no inf representation for ocp_e2m1_mxfp4
    return false;
}

template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;

    return result == 0b0;
}

template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    if(is_nan<f4_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f4_t>(scale, data))
        return 0.0f;

    f4_t prepared_data = data & 0b00001111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f4_t>(prepared_data, scale_exp);
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type<f4_t>(value);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
                    : NumericUtils<f4_t>::data_max_positive_normal_mask;

    f4_t res = convert_to_type_sr<f4_t>(value, seed);

    if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f4_t>::DataMinSubnorm())
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
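The saturating converters clamp anything beyond the format's range to the max-magnitude normal encoding and flush sub-threshold magnitudes to a signed zero. A minimal sketch of that intended behavior (illustrative; 6.0 and 0.5 are the e2m1 max-normal and min-subnormal magnitudes):

// Illustrative only: expected clamping behavior of sat_convert_to_type<f4_t>.
using namespace ck;

f4_t a = utils::sat_convert_to_type<f4_t>(100.0f);  // > 6.0  -> 0b0111 (max positive normal)
f4_t b = utils::sat_convert_to_type<f4_t>(-100.0f); //        -> 0b1111 (max negative normal)
f4_t c = utils::sat_convert_to_type<f4_t>(0.1f);    // < 0.5  -> 0b0000 (flushed to +0)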
include/ck/utility/mxf6_utils.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils
{

/**
 * @brief Checks if an f6_t value is NaN based on the provided scale.
 *
 * For f6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param dataBytes The f6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
                                             f6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if a bf6_t value is NaN based on the provided scale.
 *
 * For bf6_t data, NaN cannot be represented directly. Instead, this function
 * determines NaN by checking if the scale is set to a quiet NaN.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param dataBytes The bf6_t value to check (unused in this implementation).
 * @return true if the scale indicates a NaN value, false otherwise.
 */
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
                                              bf6_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
    return scale.is_nan();
}

/**
 * @brief Checks if an f6_t value is infinite.
 *
 * Because f6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return Always false, as infinity is not represented in f6_t.
 */
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f6_t const data [[maybe_unused]])
{
    // no inf representation for fp6
    return false;
}

/**
 * @brief Checks if a bf6_t value is infinite.
 *
 * Because bf6_t does not support infinite values, this function always returns false.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return Always false, as infinity is not represented in bf6_t.
 */
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                              bf6_t const data [[maybe_unused]])
{
    // no inf representation for bf6
    return false;
}

/**
 * @brief Checks whether an f6_t value is zero.
 *
 * If the specified f6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Checks whether a bf6_t value is zero.
 *
 * If the specified bf6_t is NaN, this function returns false.
 * Otherwise, it masks out the sign bits and checks if the remaining bits
 * are zero.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to check.
 * @return true if the value is zero; otherwise false.
 */
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
    bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;

    return result == 0b0;
}

/**
 * @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the f6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param data The f6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
    if(is_nan<f6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f6_t>(scale, data))
        return 0.0f;

    f6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<f6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
 *
 * Checks if the bf6_t value is NaN or zero before performing the conversion.
 * Applies the exponent from the scale to compute the final float result.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param data The bf6_t value to convert.
 * @return The converted float value.
 */
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
    if(is_nan<bf6_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<bf6_t>(scale, data))
        return 0.0f;

    bf6_t prepared_data = data & 0b00111111;

    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);

    return convert_to_float<bf6_t>(prepared_data, scale_exp);
}

/**
 * @brief Converts a float to f6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type<f6_t>(value);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
    {
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
    }

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type<bf6_t>(value);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to f6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for f6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Random seed used for stochastic rounding.
 * @return The saturated f6_t value.
 */
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
                    : NumericUtils<f6_t>::data_max_positive_normal_mask;

    f6_t res = convert_to_type_sr<f6_t>(value, seed);

    if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<f6_t>::DataMinSubnorm())
        return sign ? NumericUtils<f6_t>::negative_zero_mask
                    : NumericUtils<f6_t>::positive_zero_mask;

    return res;
}

/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed Random seed used for stochastic rounding.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
    cvt t;
    t.value_float = value;
    uint32_t sign = t.value_bitwise >> 31;

    if(std::isnan(value))
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;

    bf6_t res = convert_to_type_sr<bf6_t>(value, seed);

    if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
       NumericLimits<bf6_t>::DataMinSubnorm())
        return sign ? NumericUtils<bf6_t>::negative_zero_mask
                    : NumericUtils<bf6_t>::positive_zero_mask;

    return res;
}

} // namespace ck::utils
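Because f6_t and bf6_t have no NaN encoding of their own, NaN travels entirely in the e8m0 scale byte, which all the checks above rely on. A small illustration (hypothetical values, assuming the ck::utils API above):

// Illustrative only: NaN for mx f6 lives in the scale, not the payload.
ck::e8m0_bexp_t nan_scale{0xFF}; // all-ones exponent byte is the NaN encoding
ck::f6_t payload = 0b010101;     // any 6-bit payload

bool n  = ck::utils::is_nan<ck::f6_t>(nan_scale, payload);   // true, regardless of payload
float f = ck::utils::to_float<ck::f6_t>(nan_scale, payload); // quiet NaN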
include/ck/utility/mxf8_utils.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace ck
{
namespace fp8_impl
{
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val;
    val.i8val[0] = v;

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    const auto i16val = bit_cast<uint16_t>(v);

    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
    }
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t
cast_to_f8_from_f32_scaled(float v, unsigned int rng = 0, float scale = 1.0f)
{
    fp8_storage_t i8data;
    union
    {
        float fval;
        unsigned int i32val;
    } val;

    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    // unsigned int ival = 0;
    val.fval = v;

    if constexpr(stochastic_rounding)
    {
        ret.ival =
            (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
                ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
                : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
        i8data = ret.v4i8[0];
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(
                /*old_vdst*/ ret.v2i16, val.fval, val.fval, scale, /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(
                /*old_vdst*/ ret.v2i16, val.fval, val.fval, scale, /*dst_lo_hi_sel*/ false);
        }
        i8data = ret.v4i8[0];
    }

    return i8data;
}

template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t
cast_to_f8_from_f32_scaled(float2_t v, unsigned int rng = 0, float scale = 1.0f)
{
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If fval / scale > max fp8, returns NaN
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If fval / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(
                /*old_vdst*/ ret.v2i16, v[0], v[1], scale, /*dst_lo_hi_sel*/ false);
        }
        return ret.v2f8x2(Number<0>{});
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    __is_interpret_supported(interp);

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    __is_interpret_supported(interp);

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}

/**
 * \brief convert two float to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
    }
    else
    {
        __hip_assert(false && "FP8 type is not supported by current target device");
        return 0;
    }
}
#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl
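Both paths implement the same contract: the stored fp8 payload represents f / scale, so decoding multiplies the scale back in. A rough sketch of that invariant (illustrative, assuming the non-gfx950 path and OCP e4m3; values chosen so the quotient is exactly representable):

// Illustrative only: round-trip contract of the scaled fp8 conversion.
float f     = 48.0f;
float scale = 8.0f; // e8m0 scales are always powers of two

// encode: stores an e4m3 value close to f / scale = 6.0
auto enc = ck::fp8_impl::cvt_float_to_fp8_scaled<ck::ck_fp8_interpretation_t::CK_E4M3_OCP>(f, scale);

// decode later as scale * fp8 ~= 8.0 * 6.0 == 48.0, so the pair (scale, enc)
// reproduces f up to fp8 rounding error.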
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);

// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}

// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
    return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    return bf8x2_ocp_t{
        fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}

// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        f8x16_ocp_t fp8_1x16;
        f8x2_ocp_t fp8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.fp8_1x16;
}

// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } in{x};
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}(
        [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });

    return out.bf8_1x16;
}

// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        f8x32_ocp_t fp8_1x32;
        f8x16_ocp_t fp8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.fp8_1x32;
}

// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } in{x};
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}(
        [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });

    return out.bf8_1x32;
}

} // namespace ck
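The wide overloads simply fan out over the 2-element primitive via static_for, so converting a whole register fragment is one call. A hypothetical call site (illustrative; the accumulator value is a zero-initialized placeholder):

// Illustrative only: converting a 16-wide float accumulator fragment to MXFP8.
ck::float16_t acc{};       // placeholder accumulator fragment
float block_scale = 2.0f;  // shared power-of-two scale for the block

auto tile_fp8 = ck::mxf8_convert_rne<ck::f8x16_ocp_t>(acc, block_scale);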
include/ck/utility/mxfp_utils.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils
{

union cvt
{
    float value_float;
    uint32_t value_bitwise;
};

template <typename DTYPE>
inline bool getDataHasInf()
{
    return DTYPE::dataInfo.hasInf;
}

template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
    x >>= NumericUtils<T>::mant;
    x &= ((1 << NumericUtils<T>::exp) - 1);
    return static_cast<int>(x);
}

template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    return get_exponent_value<T>(x) == 0;
}

template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;

    for(uint i = 0; i < NumericUtils<T>::mant; i++)
    {
        mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
        x >>= 1;
    }

    return mantissa;
}

template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
    return NumericUtils<T>::has_inf;
}

template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    float d_sign =
        std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));

    float d_exp;
    if(is_subnormal<T>(data))
        d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
    else
        d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));

    float d_mant = get_mantissa_value<T>(data);

    float data_value = d_sign * d_exp * d_mant;

    float scale_value = std::pow(
        2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));

    return data_value * scale_value;
}
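Plugging numbers through convert_to_float makes the sign/exponent/mantissa decomposition concrete. For an f6_t (e2m3) payload 0b010101 with scale byte 129: sign = +1, exponent term 2^(2-1) = 2, mantissa 1 + 2^-1 + 2^-3 = 1.625, so the data value is 3.25; the scale contributes 2^(129-127) = 4, giving 13.0. A tiny host check of the same arithmetic (illustrative only):

// Illustrative only: decoding f6_t (e2m3) payload 0b010101 with scale byte 129.
#include <cstdio>

int main()
{
    const float mant  = 1.0f + 0.5f + 0.125f; // implicit 1 + mantissa bits 0b101
    const float val   = mant * 2.0f;          // exponent field 2, bias 1 -> 2^1
    const float scale = 4.0f;                 // e8m0 129 -> 2^(129-127)
    std::printf("%f\n", val * scale);         // prints 13.000000
}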
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);

template <typename T>
__host__ __device__ T sat_convert_to_type(float value);

template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);

template <typename T>
inline T convert_to_type(float value)
{
    using bitwise_type = typename NumericUtils<T>::bitwise_type;

    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float        = max_value;
        uint32_t max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        bitwise_type sign =
            t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        bitwise_type exp =
            ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
            (NumericUtils<float>::bias - NumericUtils<T>::bias);
        bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;

        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                   (exp << NumericUtils<T>::mant) | mantissa;
        }
        else
        {
            if(!get_data_has_inf<T>())
            {
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            }
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                       (exp << NumericUtils<T>::mant);
            }
        }
    }

    const int mfmt = NumericUtils<float>::mant;
    uint32_t x;
    x = bit_cast<uint32_t>(value);

    uint32_t head, mantissa;
    int32_t exponent, bias;
    uint32_t sign;

    head     = x & NumericUtils<float>::head_mask;
    mantissa = x & NumericUtils<float>::mant_mask;
    exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    sign     = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    bias     = NumericUtils<float>::bias;

    if(x == 0)
    {
        return 0b0;
    }

    const int mini_bias                  = NumericUtils<T>::bias;
    const int mini_denormal_act_exponent = 1 - mini_bias;

    int act_exponent, out_exponent, exponent_diff;

    bool is_subnorm = false;

    if(exponent == 0)
    {
        act_exponent  = exponent - bias + 1;
        exponent_diff = mini_denormal_act_exponent - act_exponent;
        is_subnorm    = true;
    }
    else
    {
        act_exponent = exponent - bias;
        if(act_exponent <= mini_denormal_act_exponent)
        {
            exponent_diff = mini_denormal_act_exponent - act_exponent;
            is_subnorm    = true;
        }
        else
        {
            exponent_diff = 0;
        }
        mantissa += (1UL << mfmt);
    }

    auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
    shift_amount      = (shift_amount >= 64) ? 63 : shift_amount;

    bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));

    float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
    if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
    {
        // closer to 0
        if(std::abs(value) <= std::abs(min_subnorm - value))
            return 0;
        else
            return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
    }

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;

    bool implicit_one = mantissa & (1 << mfmt);

    out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);

    uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
    bool odd           = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
    mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;

    if(out_exponent == 0)
    {
        if((1UL << mfmt) & mantissa)
        {
            out_exponent = 1;
        }
    }
    else
    {
        if((1UL << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
        }
    }

    mantissa >>= (mfmt - NumericUtils<T>::mant);

    if(out_exponent == 0 && mantissa == 0)
    {
        return 0;
    }

    mantissa &= (1UL << NumericUtils<T>::mant) - 1;
    return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
           (out_exponent << NumericUtils<T>::mant) | mantissa;
}
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;
        // cppcheck-suppress redundantAssignment
        t.value_float    = max_value;
        uint max_bitwise = t.value_bitwise;
        // cppcheck-suppress redundantAssignment
        t.value_float = value;

        T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        T exp  = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
                (NumericUtils<float>::bias - NumericUtils<T>::bias);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
        mant_prev--;
        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;

        float diff       = max_value - prev_val;
        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            double d_max_value  = static_cast<double>(max_value);
            double d_actual_max = static_cast<double>(actual_max);
            double d_value      = static_cast<double>(value);
            double d_is         = std::abs(d_max_value - d_actual_max);
            double d_seed       = static_cast<double>(seed);
            double d_prob       = 1.0f - (std::abs(d_value - d_max_value) / d_is);
            // prob to round down
            double thresh = UINT_MAX * d_prob;

            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
                return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
                                 : NumericUtils<f4_t>::data_max_negative_normal_mask;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
        else
        {
            if(!get_data_has_inf<T>())
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
    }

    uint32_t f32 = bit_cast<uint32_t>(value);

    auto f32_mant = f32 & NumericUtils<float>::mant_mask;
    auto head     = f32 & NumericUtils<float>::head_mask;
    auto f32_exp  = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    auto sign     = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);

    f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;

    int32_t exp = f32_exp;
    auto mant   = f32_mant;

    bool subnorm = false;

    if(f32 == 0)
        return 0b0;

    if(exp >= NumericUtils<T>::unbiased_exp_min)
    {
        mant = f32_mant;
    }
    // if the exponent bit is 8, then the subnormal is exactly the same as f32
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
    {
        subnorm   = true;
        auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
        if(diff >= 32)
        {
            mant     = 0;
            f32_mant = 0;
        }
        else
        {
            f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
            f32_mant >>= diff;
        }
        exp  = 0;
        mant = f32_mant;
    }

    uint32_t sr_shift = NumericUtils<T>::sr_shift;

    // For stochastic-rounding we add the aligned random value to the
    // mantissa and then truncate (RTZ).
    mant += seed >> sr_shift;

    // Increment exponent when mantissa overflows due to rounding
    if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
        ++exp;

    mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
    mant &= ((1 << NumericUtils<T>::mant) - 1);

    auto biased_exp = static_cast<uint32_t>(exp);
    if(!subnorm)
        biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);

    biased_exp &= ((1 << NumericUtils<T>::exp) - 1);

    auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
    return val;
}
} // namespace ck::utils
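The SR path above turns the rounding direction into a Bernoulli draw: random bits aligned below the kept mantissa bits are added, then the result is truncated, so the probability of rounding up equals the discarded fraction. A standalone sketch of the same idea (generic, not CK's exact helper):

// Illustrative only: stochastic rounding of a 23-bit fraction down to `kept` bits.
#include <cstdint>

uint32_t sr_round(uint32_t mant23, uint32_t rand32, int kept)
{
    const int drop = 23 - kept;       // low bits to discard
    mant23 += rand32 >> (32 - drop);  // random value uniform over the dropped range
    // a carry out of bit 23 must bump the exponent, as in convert_to_type_sr
    return mant23 >> drop;            // truncate (round toward zero)
}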
include/ck/utility/scaled_type_convert.hpp
0 → 100644
View file @
9ba504b6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#ifdef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace ck
{

// Declare a template function for scaled conversion
template <typename Y, typename X>
#if CK_USE_OCP_FP8
__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#else
__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x);
#endif

// convert f8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, f8_ocp_t>(e8m0_bexp_t scale, f8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<f8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.data);
#else
    return type_convert<float>(scale) * type_convert<float>(x);
#endif
}

// convert bf8_ocp_t to fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8_ocp_t x)
#else
inline __host__ float scaled_type_convert<float, bf8_ocp_t>(e8m0_bexp_t scale, bf8_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32_from_f8_scaled<bf8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.data);
#else
    return type_convert<float>(scale) * type_convert<float>(x);
#endif
}

// convert 2 x f8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale,
                                                                              f8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, f8x2_ocp_t>(e8m0_bexp_t scale, f8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<f8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    return float2_t{scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<0>{}]),
                    scaled_type_convert<float>(scale, x.AsType<f8_ocp_t>()[Number<1>{}])};
#endif
}

// convert 2 x bf8_ocp_t to 2 x fp32
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
                                                                               bf8x2_ocp_t x)
#else
inline __host__ float2_t scaled_type_convert<float2_t, bf8x2_ocp_t>(e8m0_bexp_t scale,
                                                                    bf8x2_ocp_t x)
#endif
{
#if CK_MX_FP8_CVT_FAST_PATH
    return fp8_impl::cast_to_f32x2_from_f8x2_scaled<bf8_ocp_t::default_interpret>(
        type_convert<float>(scale), x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
#else
    return float2_t{scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<0>{}]),
                    scaled_type_convert<float>(scale, x.AsType<bf8_ocp_t>()[Number<1>{}])};
#endif
}

// convert 16 x f8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
                                                                                 f8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, f8x16_ocp_t>(e8m0_bexp_t scale,
                                                                      f8x16_ocp_t x)
#endif
{
    union
    {
        f8x16_ocp_t f8_1x16;
        f8x2_ocp_t f8_2x8[8];
    } in{x};
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}([&](auto i) {
        out.float_2x8[i] = scaled_type_convert<float2_t, f8x2_ocp_t>(scale, in.f8_2x8[i]);
    });

    return out.float_1x16;
}

// convert 16 x bf8_ocp_t to 16 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
                                                                                  bf8x16_ocp_t x)
#else
inline __host__ float16_t scaled_type_convert<float16_t, bf8x16_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8x16_ocp_t x)
#endif
{
    union
    {
        bf8x16_ocp_t bf8_1x16;
        bf8x2_ocp_t bf8_2x8[8];
    } in{x};
    union
    {
        float16_t float_1x16;
        float2_t float_2x8[8];
    } out{};

    ck::static_for<0, 8, 1>{}([&](auto i) {
        out.float_2x8[i] = scaled_type_convert<float2_t, bf8x2_ocp_t>(scale, in.bf8_2x8[i]);
    });

    return out.float_1x16;
}

// convert 32 x f8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
                                                                                 f8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, f8x32_ocp_t>(e8m0_bexp_t scale,
                                                                      f8x32_ocp_t x)
#endif
{
    union
    {
        f8x32_ocp_t f8_1x32;
        f8x16_ocp_t f8_16x2[2];
    } in{x};
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}([&](auto i) {
        out.float_16x2[i] = scaled_type_convert<float16_t, f8x16_ocp_t>(scale, in.f8_16x2[i]);
    });

    return out.float_1x32;
}

// convert 32 x bf8_ocp_t to 32 x fp32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ float32_t
scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale, bf8x32_ocp_t x)
#else
inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp_t scale,
                                                                       bf8x32_ocp_t x)
#endif
{
    union
    {
        bf8x32_ocp_t bf8_1x32;
        bf8x16_ocp_t bf8_16x2[2];
    } in{x};
    union
    {
        float32_t float_1x32;
        float16_t float_16x2[2];
    } out{};

    ck::static_for<0, 2, 1>{}([&](auto i) {
        out.float_16x2[i] = scaled_type_convert<float16_t, bf8x16_ocp_t>(scale, in.bf8_16x2[i]);
    });

    return out.float_1x32;
}
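A hypothetical decode call site (illustrative; the packed payload is a zero-initialized placeholder) shows how one e8m0 block scale is shared by every element of a packed vector:

// Illustrative only: decoding two packed OCP fp8 values with one shared block scale.
ck::e8m0_bexp_t scale{130}; // 2^(130-127) = 8
ck::f8x2_ocp_t packed{};    // placeholder payload

ck::float2_t f = ck::scaled_type_convert<ck::float2_t>(scale, packed);
// each lane decodes as 8 * fp8_value; with a zero payload both lanes are 0.0f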
// convert fp32 to fp8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#else
inline __host__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32 to bf8
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale,
                                                                           float x)
#else
inline __host__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_bexp_t scale, float x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x2 to fp8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                float2_t x)
#else
inline __host__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(e8m0_bexp_t scale, float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x2 to bf8x2
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                                  float2_t x)
#else
inline __host__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t>(e8m0_bexp_t scale,
                                                                       float2_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x2_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x2_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x16_ocp_t
scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ f8x16_ocp_t scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                        float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x16_ocp_t
scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
#else
inline __host__ bf8x16_ocp_t scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale,
                                                                          float16_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x16_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x16_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ f8x32_ocp_t
scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ f8x32_ocp_t scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                        float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<f8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<f8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}

// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
template <>
#if CK_USE_OCP_FP8
inline __host__ __device__ bf8x32_ocp_t
scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#else
inline __host__ bf8x32_ocp_t scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale,
                                                                          float32_t x)
#endif
{
#if CK_USE_SR_F8_CONVERSION
    return mxf8_convert_sr<bf8x32_ocp_t>(x, type_convert<float>(scale));
#else
    return mxf8_convert_rne<bf8x32_ocp_t>(x, type_convert<float>(scale));
#endif
}
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{};
value
.
f4x2_array
[
0
]
=
x
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{})),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}))};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m0_bexp_t scale,
                                                                             f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    // convert one packed fp4 pair per iteration
    for(int i = 0; i < 16; ++i)
    {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret[2 * i]     = op[0];
        ret[2 * i + 1] = op[1];
    }
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // unpack both nibbles of each of the 16 packed bytes into floats
    for(int i = 0; i < 16; ++i)
    {
        float_values.float_array[2 * i] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * i + 1] = utils::to_float<f4_t>(
            scale,
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    }
    return float_values.float32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t, float2_t>(e8m0_bexp_t scale,
                                                                        float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x, type_convert<float>(scale));
#else
    return f4_convert_rne(x, type_convert<float>(scale));
#endif
}
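// Illustrative sketch (assuming one e8m0 scale per 32-element MX block): quantize a
// block of floats to fp4 and dequantize it again with the specializations above.
// `block_scale` and `block` are hypothetical values supplied by the caller.
//
//   e8m0_bexp_t block_scale = ...; // e.g. derived from the block's max magnitude
//   float32_t block         = ...; // 32 source values
//   f4x32_t q   = scaled_type_convert<f4x32_t, float32_t>(block_scale, block);
//   float32_t d = scaled_type_convert<float32_t, f4x32_t>(block_scale, q);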
/**
 * @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x The f6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t scale, f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(scale, x);
#endif
}
/**
 * @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The f6x32_t vector to be converted.
 * @return The converted float vector representation of the input.
 */
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m0_bexp_t scale,
                                                                             f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<f6_t>(scale, in.f6_array[i]); });
    return out.float_vector;
#endif
}
/**
 * @brief Converts a 6-bit floating-point value (bf6_t) to a 32-bit float,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x The bf6_t value to be converted.
 * @return The converted 32-bit float representation of the input.
 */
template <>
inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t scale, bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector =
        __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(scale, x);
#endif
}
/**
 * @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
 * applying the specified scaling factor.
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The bf6x32_t vector to be converted.
 * @return The converted vector of 32 floats.
 */
template <>
inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8m0_bexp_t scale,
                                                                              bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}(
        [&](auto i) { out.float_array[i] = utils::to_float<bf6_t>(scale, in.bf6_array[i]); });
    return out.float_vector;
#endif
}
/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the specified
 * scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
 * @param x The float value to convert.
 * @return The converted 6-bit floating-point value (f6_t).
 */
template <>
inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
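// Note: the rounding mode of this specialization is fixed at compile time, so the same
// call site yields stochastic rounding when CK_USE_SR_F6_CONVERSION is set and
// round-to-nearest-even otherwise, e.g.
//
//   f6_t q = scaled_type_convert<f6_t, float>(scale, 0.75f); // `scale` is an e8m0_bexp_t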
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
 * applying the specified scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The float vector to convert.
 * @return The converted vector of 6-bit floating-point values (f6x32_t).
 */
template <>
inline __host__ __device__ f6x32_t scaled_type_convert<f6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                           float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x, type_convert<float>(scale));
#else
    return f6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
 * @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
 * scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
 * @param x The float value to convert.
 * @return The converted 6-bit floating-point value (bf6_t).
 */
template <>
inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t scale, float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
/**
 * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
 * applying the specified scale.
 *
 * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
 * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
 *
 * @param scale The exponent scale factor (e8m0_bexp_t).
 * @param x The float vector to convert.
 * @return The converted 6-bit floating-point vector (bf6x32_t).
 */
template <>
inline __host__ __device__ bf6x32_t scaled_type_convert<bf6x32_t, float32_t>(e8m0_bexp_t scale,
                                                                             float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x, type_convert<float>(scale));
#else
    return bf6_convert_rne(x, type_convert<float>(scale));
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT

} // namespace ck
include/ck/utility/type_convert.hpp View file @ 9ba504b6
...
...
@@ -5,15 +5,39 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/mxf4_utils.hpp"
#include "ck/utility/mxf6_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"

namespace ck {

// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
namespace {
namespace details {

[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    half2_t vector_res;
    vector_res.x = x.x + y.x;
    vector_res.y = x.y + y.y;
    return vector_res;
}

[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
    return amd_assembly_pk_add_f16(x, y);
}

} // namespace details
} // namespace
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
...
...
@@ -520,13 +544,51 @@ template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    float2_t res = {x_h, x_l};
#else
    float2_t res = {x_l, x_h};
#endif
    return res;
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
    uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
    const int EX  = 0x64006400; // fp16 bit pattern of 1024.0 in each half
    const int SUB = 0xE408E408; // fp16 bit pattern of -1032.0 = -(1024 + 8) in each half
    // (i4s | EX) reads each nibble n as the fp16 value 1024 + n; adding SUB leaves n - 8
    int lo = i4s | EX;
    return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
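// Worked example of the bit trick above: for nibble n = 0xF, (n | 0x6400) is the fp16
// value 1024 + 15 = 1039, and 0xE408 is -1032, so the packed fp16 add yields 7, i.e.
// the signed int4 value 0xF - 8, matching the float path's `- 8.f` correction.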
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
    float x_l    = ((x_u8 & 0x0f) >> 0) - 8.f;
    float x_h    = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
    bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
    bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
    return res;
}

template <>
...
...
@@ -647,6 +709,1278 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
#endif
}
// convert fp32 to fp4 with rounding to nearest even
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type<f4_t>(x / scale);
#endif
}
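// Illustrative example: with scale = 2.0f, f4_convert_rne(3.0f, 2.0f) quantizes
// 3.0 / 2.0 = 1.5, which lies exactly on the fp4 (e2m1) grid, so the stored code
// decodes back to 1.5 (i.e. 3.0 once the scale is re-applied).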
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise =
        __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l     = utils::sat_convert_to_type<f4_t>(x[1] / scale);
    uint8_t h     = utils::sat_convert_to_type<f4_t>(x[0] / scale);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{}, tmp_values{};
    // convert one pair of floats into one packed fp4 byte per iteration
    for(int i = 0; i < 16; ++i)
    {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
            tmp_values.bitwise, x[2 * i], x[2 * i + 1], scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    }
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{};
    // shift each converted nibble in from the right; x[0] ends up in the most
    // significant nibble and x[31] in the least significant one
    for(int i = 0; i < 32; ++i)
    {
        auto tmp = utils::sat_convert_to_type<f4_t>(x[i] / scale);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    }
    return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4_t f4_array[4];
    } value{0};
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{{x}};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
        value.bitwise, float_values.float2_array, rng, scale, 0);
    return value.f4_array[0];
#else
    return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
#endif
}
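// Illustrative note (based on the standard definition of stochastic rounding; the exact
// distribution is determined by sat_convert_to_type_sr and the hardware path): a value
// that falls between two fp4 grid points, e.g. 1.25 between 1.0 and 1.5, rounds to
// either neighbour with probability proportional to proximity, so repeated conversions
// average out near the input rather than always snapping to the same neighbour.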
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
    return value.f4x2_array[0];
#else
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{0};
    uint8_t l     = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
    uint8_t h     = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
    value.bitwise = (h << 4) | l;
    return value.f4x2_array[0];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
#if defined(__gfx950__)
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0}, tmp_values{0};
    union
    {
        float2_t floatx2_array[16];
        float32_t floatx32_array;
    } float_values{};
    float_values.floatx32_array = x; // stage the input so it can be read two floats at a time
    for(int i = 0; i < 16; ++i)
    {
        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
            tmp_values.bitwise, float_values.floatx2_array[i], rng, scale, 0);
        f4_values.f4x2_array[i] = tmp_values.f4x2_array[0];
    }
    return f4_values.f4x32_array;
#else
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{0};
    // shift each stochastically rounded nibble in from the right
    for(int i = 0; i < 32; ++i)
    {
        auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[i] / scale, rng);
        f4_values.bitwise <<= 4;
        f4_values.bitwise |= tmp;
    }
    return f4_values.f4x32_array;
#endif
}
// convert fp32 to fp4
template <>
inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template <>
inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>
inline __host__ __device__ f4x32_t type_convert<f4x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F4_CONVERSION
    return f4_convert_sr(x);
#else
    return f4_convert_rne(x);
#endif
}
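// Illustrative sketch: these unscaled overloads are the scaled conversions at an
// implicit scale of 1.0f, e.g.
//
//   float32_t src = ...; // hypothetical caller data
//   f4x32_t q     = type_convert<f4x32_t, float32_t>(src);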
// convert fp4 to fp32
template <>
inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
{
#if defined(__gfx950__)
    union
    {
        float float_array[2];
        float2_t float2_array;
    } float_values{};
    float scale               = 1.0f;
    float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
    return float_values.float_array[0];
#else
    return utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
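// For reference, the eight non-negative fp4 (e2m1) codes decode to
// {0, 0.5, 1, 1.5, 2, 3, 4, 6}; the sign bit mirrors them to negative values. This is
// the grid that utils::to_float<f4_t> reproduces at unit scale.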
// convert vector of 2 fp4 to vector of 2 fp32
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
{
#if defined(__gfx950__)
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } value{};
    value.f4x2_array[0] = x;
    float scale         = 1.0f;
    return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
#else
    float2_t ret{utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                       x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                                           Number<1>{})),
                 utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
                                       x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                                           Number<0>{}))};
    return ret;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
{
#if defined(__gfx950__)
    union
    {
        f4x32_t f4x32_array;
        f4x2_t fp4x2[16];
    } value{x};
    union
    {
        uint32_t bitwise;
        f4x2_t f4x2_array[4];
    } bitwise_value{};
    float2_t op;
    float32_t ret;
    float scale = 1.0f;
    // convert one packed fp4 pair per iteration
    for(int i = 0; i < 16; ++i)
    {
        bitwise_value.f4x2_array[0] = value.fp4x2[i];
        op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
            bitwise_value.bitwise, type_convert<float>(scale), 0);
        ret[2 * i]     = op[0];
        ret[2 * i + 1] = op[1];
    }
    return ret;
#else
    union
    {
        float32_t float32_array;
        float float_array[32];
    } float_values{};
    union
    {
        __uint128_t bitwise;
        f4x2_t f4x2_array[16];
        f4x32_t f4x32_array;
    } f4_values{bit_cast<__uint128_t>(x)};
    // unpack both nibbles of each of the 16 packed bytes at unit scale
    for(int i = 0; i < 16; ++i)
    {
        float_values.float_array[2 * i] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<0>{}));
        float_values.float_array[2 * i + 1] = utils::to_float<f4_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(),
            f4_values.f4x2_array[i].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
                Number<1>{}));
    }
    return float_values.float32_array;
#endif
}
/**
 * @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
 *
 * Divides the input by the specified scale, then saturates and converts it
 * to the 6-bit floating-point format (f6_t).
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type<f6_t>(x / scale);
#endif
}
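// Illustrative note (assuming f6_t is the OCP FP6 e2m3 format used elsewhere in CK):
// at unit scale the RNE conversion snaps to a grid with spacing 0.125 in [1, 2) and a
// maximum representable magnitude of 7.5, beyond which values saturate.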
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion
 * with rounding to nearest even to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // view the 32 floats as two float16_t halves (elements 0..15 and 16..31)
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(&x) + 1;
    return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
 *
 * Divides the input by the specified scale, then performs saturation and conversion
 * to f6_t based on a pseudo-randomly generated seed.
 *
 * @param x The input float value.
 * @param scale A scaling factor applied to `x` before conversion.
 * @return The converted f6_t value.
 */
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
    return out.f6_array[0];
#else
    return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
#endif
}
/**
 * @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
 *
 * This function divides each input float by the provided scale value, then performs conversion
 * with stochastic rounding to pack each element into 6 bits of precision.
 *
 * @param x A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });
    return out.f6_vector;
#endif
}
/**
 * @brief Specializes the type conversion template for converting a float into the 6-bit float
 * type (f6_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses either stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float value to be converted.
 * @return The converted f6_t value.
 */
template <>
inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting a vector of 32 floats into a
 * vector of 32 6-bit float types (f6x32_t).
 *
 * Depending on the CK_USE_SR_F6_CONVERSION flag, the conversion uses either stochastic rounding
 * or round-to-nearest-even.
 *
 * @param x Input float vector to be converted.
 * @return The converted f6x32_t vector.
 */
template <>
inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return f6_convert_sr(x);
#else
    return f6_convert_rne(x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
 * float.
 *
 * Interprets an f6_t value as a float using the default scale factor of 1.
 *
 * @param x The 6-bit float (f6_t) value to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
{
#if defined(__gfx950__)
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
 * @brief Specializes the type conversion template for converting the vector of 32 6-bit float
 * types (f6x32_t) to a vector of 32 floats.
 *
 * Interprets the f6_t values as floats using the default scale factor of 1.
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding float representation.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.f6_array[i]);
    });
    return out.float_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
{
#if defined(__gfx950__)
    float16_t in1{x};
    float16_t in2{};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type<bf6_t>(x / scale);
#endif
}
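Since the scale divides the input before encoding, the stored value represents x / scale; a consumer must multiply by the same scale on decode. A hedged example:

// Sketch: with scale = 2.0f the encoder stores 3.0 / 2.0 = 1.5, which is
// exactly representable in the 6-bit bf6 format.
inline __host__ __device__ bf6_t bf6_scale_sketch()
{
    return bf6_convert_rne(3.0f, 2.0f); // decodes to 1.5f at scale 1
}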
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
* round-to-nearest-even.
*
* Divides the input by the specified scale, then saturates and converts
* it to a 6-bit BF6 floating-point format.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // Split the 32-float input into two 16-float halves. The offset to the
    // second half is 16 floats (not 16 float32_t vectors), so advance a
    // float* by 16 elements before reinterpreting.
    float16_t* in1 = reinterpret_cast<float16_t*>(&x);
    float16_t* in2 = reinterpret_cast<float16_t*>(reinterpret_cast<float*>(&x) + 16);
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Converts a float to the 6-bit BF6 type using stochastic rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float value to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6_t value.
*/
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx950__)
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
    return out.bf6_array[0];
#else
    return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
#endif
}
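For intuition, the decision rule behind stochastic rounding can be written out directly; this sketch mirrors what utils::sat_convert_to_type_sr does in spirit only (lo/hi stand for the two representable neighbors of x):

// Sketch: round x to lo or hi with probability proportional to proximity,
// so the expected value of the result equals x.
inline float stochastic_round_sketch(float lo, float hi, float x, uint32_t rng)
{
    float frac = (x - lo) / (hi - lo);         // position of x in [lo, hi)
    float u    = rng * (1.0f / 4294967296.0f); // uniform sample in [0, 1)
    return (u < frac) ? hi : lo;               // E[result] == x
}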
/**
* @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
* rounding.
*
* Divides the input by the specified scale,
* and converts the result to a 6-bit BF6 floating-point
* format with stochastic rounding.
*
* @param x The float vector to be converted.
* @param scale The scaling factor applied to the input before conversion.
* @return The converted bf6x32_t vector.
*/
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;
    union
    {
        float32_t float_vector;
        float float_array[32];
    } float_values{x};
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x),
                                                float_values.float_array[0]);
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] =
            utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });
    return out.bf6_vector;
#endif
}
/**
* @brief Specializes float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float value to convert.
* @return Converted bf6_t value.
*/
template <>
inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template <>
inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t x)
{
#if CK_USE_SR_F6_CONVERSION
    return bf6_convert_sr(x);
#else
    return bf6_convert_rne(x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
* Interprets the bf6_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6_t value to convert.
* @return The float representation of the given bf6_t value.
*/
template <>
inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
{
#if defined(__gfx950__)
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
    return out.float_array[0];
#else
    return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to
* vector of 32 floats.
*
* Interprets the bf6x32_t value using the default scale factor of 1 and returns
* its floating-point representation.
*
* @param x The bf6x32_t value to convert.
* @return The float representation of the given vector.
*/
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
        x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } in{x};
    union
    {
        float32_t float_vector;
        float float_array[32];
    } out{};
    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.float_array[i] =
            utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), in.bf6_array[i]);
    });
    return out.float_vector;
#endif
}
#ifndef CK_CODE_GEN_RTC
template <typename Y, typename X, size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
...
...
include/ck_tile/core.hpp View file @ 9ba504b6
...
...
@@ -32,6 +32,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
...
...
include/ck_tile/core/arch/generic_memory_space_atomic.hpp View file @ 9ba504b6
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
...
...
@@ -8,16 +8,75 @@
 namespace ck_tile {

-CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
+template <typename T, typename ComputeType>
+CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
 {
-    return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
+    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
 }
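The templated helper generalizes the old bf16-only add: operands are promoted to ComputeType, added there, and rounded back only once on the result. A hedged usage sketch:

// Sketch: fp8 addition carried out in float; the narrow type rounds only
// once, when the sum is converted back.
CK_TILE_HOST_DEVICE fp8_t add_fp8_sketch(fp8_t a, fp8_t b)
{
    return add<fp8_t, float>(a, b);
}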
 CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
 {
     bf16x2_t rtn;
-    rtn[0] = add_bf16_t(a[0], b[0]);
-    rtn[1] = add_bf16_t(a[1], b[1]);
+    rtn[0] = add<bf16_t, float>(a[0], b[0]);
+    rtn[1] = add<bf16_t, float>(a[1], b[1]);
     return rtn;
 }
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    rtn[2] = add<bf16_t, float>(a[2], b[2]);
    rtn[3] = add<bf16_t, float>(a[3], b[3]);
    return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    return rtn;
}
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    rtn[4] = add<fp8_t, float>(a[4], b[4]);
    rtn[5] = add<fp8_t, float>(a[5], b[5]);
    rtn[6] = add<fp8_t, float>(a[6], b[6]);
    rtn[7] = add<fp8_t, float>(a[7], b[7]);
    return rtn;
}
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    return rtn;
}
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    rtn[4] = add<bf8_t, float>(a[4], b[4]);
    rtn[5] = add<bf8_t, float>(a[5], b[5]);
    rtn[6] = add<bf8_t, float>(a[6], b[6]);
    rtn[7] = add<bf8_t, float>(a[7], b[7]);
    return rtn;
}
...
...
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
    } while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // Union to treat the pointer as either bf16x4_t* or uint64_t*:
    union U64BF164_ADDR
    {
        uint64_t* u64_a;
        bf16x4_t* bf164_a;
    };
    // Union to treat the data as either bf16x4_t or 64-bit integer
    union U64BF164
    {
        uint64_t u64;
        bf16x4_t bf164;
    };

    U64BF164_ADDR addr;
    addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location

    // First read (non-atomic) of the old value
    U64BF164 cur_v;
    cur_v.u64 = *addr.u64_a;

    U64BF164 new_v_union;
    uint64_t old_v, new_v;

    do
    {
        // old 64 bits
        old_v = cur_v.u64;
        // Add elementwise in bf16
        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
        new_v = new_v_union.u64;
        // Attempt the 64-bit CAS
        cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    union U32FP84_ADDR
    {
        uint32_t* u32_a;
        fp8x4_t* fp84_a;
    };
    union U32FP84
    {
        uint32_t u32;
        fp8x4_t fp84;
    };

    U32FP84_ADDR dword_addr;
    U32FP84 cur_v;
    U32FP84 new_;
    uint32_t old_v, new_v;

    dword_addr.fp84_a = p_dst;
    cur_v.u32 = *dword_addr.u32_a;

    do
    {
        old_v = cur_v.u32;
        new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
        new_v = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    union U32BF84_ADDR
    {
        uint32_t* u32_a;
        bf8x4_t* bf84_a;
    };
    union U32BF84
    {
        uint32_t u32;
        bf8x4_t bf84;
    };

    U32BF84_ADDR dword_addr;
    U32BF84 cur_v;
    U32BF84 new_;
    uint32_t old_v, new_v;

    dword_addr.bf84_a = p_dst;
    cur_v.u32 = *dword_addr.u32_a;

    do
    {
        old_v = cur_v.u32;
        new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
        new_v = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    // Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
    union U64FP88_ADDR
    {
        uint64_t* u64_a; // pointer to 64-bit integer
        fp8x8_t* fp88_a; // pointer to fp8x8_t
    };
    union U64FP88
    {
        uint64_t u64;
        fp8x8_t fp88;
    };

    U64FP88_ADDR dword_addr;
    U64FP88 cur_v;
    U64FP88 new_v_union;
    uint64_t old_v, new_v;

    // Point to the destination as both fp8x8_t* and uint64_t*.
    dword_addr.fp88_a = p_dst;
    // Initial read of 64 bits from memory
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each fp8 element using the add_fp8x8_t(...) routine
        new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
        new_v = new_v_union.u64;
        // Attempt 64-bit CAS
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    union U64BF88_ADDR
    {
        uint64_t* u64_a;
        bf8x8_t* bf88_a;
    };
    union U64BF88
    {
        uint64_t u64;
        bf8x8_t bf88;
    };

    U64BF88_ADDR dword_addr;
    U64BF88 cur_v;
    U64BF88 new_v_union;
    uint64_t old_v, new_v;

    dword_addr.bf88_a = p_dst;
    // Read the original 64 bits
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each bf8 element using the add_bf8x8_t(...) routine
        new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
        new_v = new_v_union.u64;
        // 64-bit CAS loop
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
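All of the specializations above instantiate one pattern: snapshot the word, compute the new value in registers, publish with compare-and-swap, and retry if another thread won the race. A distilled sketch of that skeleton (WordType/VecType/add_vec are placeholders, not library names):

// Sketch: emulated vector atomic add over a 32- or 64-bit word that exactly
// overlays VecType; add_vec performs the elementwise addition.
template <typename WordType, typename VecType, typename AddFn>
CK_TILE_DEVICE void atomic_add_via_cas(VecType* p_dst, const VecType& x, AddFn add_vec)
{
    union Word { WordType w; VecType v; };
    Word cur, nxt;
    WordType* p_word = reinterpret_cast<WordType*>(p_dst);
    cur.w = *p_word; // non-atomic snapshot of the current contents
    WordType old_w;
    do
    {
        old_w = cur.w;
        nxt.v = add_vec(cur.v, x);               // new value built in registers
        cur.w = atomicCAS(p_word, old_w, nxt.w); // returns the witnessed value
    } while(cur.w != old_w);                     // retry on contention
}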
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
...
...
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
     (std::is_same<T, uint32_t>::value && (N == 1)) ||
         (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
         (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
-        (std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
-    "wrong! not implemented");
+        (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
+        (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
+        (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
+    "The granularity of the thread buffer is unsupported on the hardware!");

     constexpr auto I0 = number<0>{};
     constexpr auto I1 = number<1>{};
...
...
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
     }
     else if constexpr(N == 4)
     {
-        atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
-        atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1, x.template get_as<bf16x2_t>()[I1]);
+        atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
     }
+    else if constexpr(N == 8)
+    {
+        atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
+        atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1, x.template get_as<bf16x4_t>()[I1]);
+    }
 }
+else if constexpr(std::is_same<T, fp8_t>::value)
+{
+    if constexpr(N == 4)
+    {
+        atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
+    }
+    if constexpr(N == 8)
+    {
+        atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
+    }
+    if constexpr(N == 16)
+    {
+        atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
+        atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
+    }
+}
+else if constexpr(std::is_same<T, bf8_t>::value)
+{
+    if constexpr(N == 4)
+    {
+        atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
+    }
+    if constexpr(N == 8)
+    {
+        atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
+    }
+    if constexpr(N == 16)
+    {
+        atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
+        atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
+    }
+}
 }
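A hedged usage sketch of the extended entry point (buffer setup is assumed to happen elsewhere):

// Sketch: accumulate an 8-element fp8 fragment into global memory. With the
// new specializations, fp8_t with N == 8 dispatches to a single 64-bit CAS loop.
CK_TILE_DEVICE void accumulate_fp8_sketch(fp8_t* p_out, const thread_buffer<fp8_t, 8>& frag)
{
    atomic_add_g(p_out, frag); // elementwise fp8 += via the CAS loop above
}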
...
...
include/ck_tile/core/config.hpp View file @ 9ba504b6
...
...
@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
...
...