Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
24f6f4ab
Unverified
Commit
24f6f4ab
authored
Feb 01, 2025
by
Illia Silin
Committed by
GitHub
Feb 01, 2025
Browse files
Merge pull request #284 from ROCm/lwpck-2619
Add FP6/BF6 vector type support
parents
5eb4c3d1
df1bad99
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1445 additions
and
128 deletions
+1445
-128
CMakeLists.txt
CMakeLists.txt
+4
-0
include/ck/config.h.in
include/ck/config.h.in
+4
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+357
-2
include/ck/utility/mxfp_utils.hpp
include/ck/utility/mxfp_utils.hpp
+1
-1
include/ck/utility/scaled_type_convert.hpp
include/ck/utility/scaled_type_convert.hpp
+284
-55
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+421
-46
test/data_type/test_bf6.cpp
test/data_type/test_bf6.cpp
+170
-0
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+36
-24
test/data_type/test_fp6.cpp
test/data_type/test_fp6.cpp
+168
-0
No files found.
CMakeLists.txt
View file @
24f6f4ab
...
@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
...
@@ -210,6 +210,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
add_definitions
(
-DCK_USE_FNUZ_FP8
)
add_definitions
(
-DCK_USE_FNUZ_FP8
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
endif
()
endif
()
# When the requested GPU targets include gfx950, pass -DCK_USE_NATIVE_MX_SUPPORT to the
# compiler and set the CMake variable of the same name to "ON".
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx950"
)
add_definitions
(
-DCK_USE_NATIVE_MX_SUPPORT
)
set
(
CK_USE_NATIVE_MX_SUPPORT
"ON"
)
endif
()
option
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
"Enable FP8 GEMM instances on older architectures"
OFF
)
option
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
"Enable FP8 GEMM instances on older architectures"
OFF
)
if
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
AND
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx90a"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx908"
))
if
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
AND
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx90a"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx908"
))
...
...
include/ck/config.h.in
View file @
24f6f4ab
...
@@ -131,6 +131,10 @@
...
@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#endif
// Only takes effect when CK_USE_NATIVE_MX_SUPPORT has not been defined before this point;
// the replacement text is whatever the CMake variable of the same name holds.
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
// clang-format on
#endif // CK_CONFIG_H_IN
#endif // CK_CONFIG_H_IN
include/ck/utility/data_type.hpp
View file @
24f6f4ab
...
@@ -24,8 +24,9 @@ struct f4x2_pk_t
...
@@ -24,8 +24,9 @@ struct f4x2_pk_t
f4x2_pk_t
(
type
init
)
:
data
{
init
}
{}
f4x2_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
inline
type
unpack
()
const
__host__
__device__
inline
type
unpack
(
Number
<
I
>
)
const
{
{
static_assert
(
I
<
2
,
"Index is out of range."
);
if
constexpr
(
I
==
0
)
if
constexpr
(
I
==
0
)
return
data
&
0b00001111
;
return
data
&
0b00001111
;
else
else
...
@@ -38,6 +39,270 @@ struct f4x2_pk_t
...
@@ -38,6 +39,270 @@ struct f4x2_pk_t
}
}
};
};
/// Packed container for sixteen 6-bit f6_t values stored in three 32-bit words.
///
/// Element k occupies the six bits starting at bit position 6 * k, counted from the
/// least-significant bit of word 0 and continuing through word 1 and word 2. An element
/// may therefore straddle two neighbouring words.
struct f6x16_pk_t
{
    // store 16 elements of f6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;

    // Unpacked form accepted by pack(): one int8_t lane per element.
    using test_vec_t = int8_t __attribute__((ext_vector_type(16)));

    /// Value-initializes the storage.
    __host__ __device__ f6x16_pk_t() : data{type{}} {}

    /// Wraps already packed storage.
    __host__ __device__ f6x16_pk_t(type init) : data{init} {}

    /// Extracts element I.
    /// @tparam I compile-time element index, 0 <= I < 16.
    /// @return the six bits of element I converted to f6_t.
    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>) const
    {
        static_assert(I >= 0 && I < 16, "Index out of range for 16 f6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        // number of bits of this element that live in the next word (<= 0 when none do)
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word and place them above the bits
            // already taken from the current word
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }
        return static_cast<f6_t>(bits & 0x3F);
    }

    /// Packs the low six bits of each of the 16 lanes of x into three 32-bit words.
    /// Does not read or modify any object state.
    /// @param x vector of 16 int8_t lanes; bits above the low six of each lane are ignored.
    /// @return the packed storage.
    __host__ __device__ static inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            const uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            // insert bits into the current 32-bit block
            packed.At(Number<arr_index>{}) |= (bits << bit_offset);

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                packed.At(Number<arr_index + 1>{}) |= (bits >> (num_bits_elem - overhang));
            }
        });

        return packed;
    }
};
/// Packed container for thirty-two 6-bit f6_t values stored in six 32-bit words.
///
/// Element k occupies the six bits starting at bit position 6 * k, counted from the
/// least-significant bit of word 0 and continuing through the following words. An element
/// may therefore straddle two neighbouring words.
struct f6x32_pk_t
{
    // store 32 elements of f6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;

    // Unpacked form accepted by pack(): one int8_t lane per element.
    using test_vec_t = int8_t __attribute__((ext_vector_type(32)));

    /// Value-initializes the storage.
    __host__ __device__ f6x32_pk_t() : data{type{}} {}

    /// Wraps already packed storage.
    __host__ __device__ f6x32_pk_t(type init) : data{init} {}

    /// Extracts element I.
    /// @tparam I compile-time element index, 0 <= I < 32.
    /// @return the six bits of element I converted to f6_t.
    template <index_t I>
    __host__ __device__ inline f6_t unpack(Number<I>) const
    {
        static_assert(I >= 0 && I < 32, "Index out of range for 32 f6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        // number of bits of this element that live in the next word (<= 0 when none do)
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word and place them above the bits
            // already taken from the current word
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }
        return static_cast<f6_t>(bits & 0x3F);
    }

    /// Packs the low six bits of each of the 32 lanes of x into six 32-bit words.
    /// Does not read or modify any object state.
    /// @param x vector of 32 int8_t lanes; bits above the low six of each lane are ignored.
    /// @return the packed storage.
    __host__ __device__ static inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 f6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            const uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            // insert bits into the current 32-bit block
            packed.At(Number<arr_index>{}) |= (bits << bit_offset);

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                packed.At(Number<arr_index + 1>{}) |= (bits >> (num_bits_elem - overhang));
            }
        });

        return packed;
    }
};
/// Packed container for sixteen 6-bit bf6_t values stored in three 32-bit words.
///
/// Element k occupies the six bits starting at bit position 6 * k, counted from the
/// least-significant bit of word 0 and continuing through word 1 and word 2. An element
/// may therefore straddle two neighbouring words.
struct bf6x16_pk_t
{
    // store 16 elements of bf6_t in an array of 3 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 3>;
    type data;

    // Unpacked form accepted by pack(): one int8_t lane per element.
    using test_vec_t = int8_t __attribute__((ext_vector_type(16)));

    /// Value-initializes the storage.
    __host__ __device__ bf6x16_pk_t() : data{type{}} {}

    /// Wraps already packed storage.
    __host__ __device__ bf6x16_pk_t(type init) : data{init} {}

    /// Extracts element I.
    /// @tparam I compile-time element index, 0 <= I < 16.
    /// @return the six bits of element I converted to bf6_t.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>) const
    {
        static_assert(I >= 0 && I < 16, "Index out of range for 16 bf6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 3;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        // number of bits of this element that live in the next word (<= 0 when none do)
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word and place them above the bits
            // already taken from the current word
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }
        return static_cast<bf6_t>(bits & 0x3F);
    }

    /// Packs the low six bits of each of the 16 lanes of x into three 32-bit words.
    /// Does not read or modify any object state.
    /// @param x vector of 16 int8_t lanes; bits above the low six of each lane are ignored.
    /// @return the packed storage.
    __host__ __device__ static inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 16 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 16, 1>{}([&](auto i) {
            const uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 3;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            // insert bits into the current 32-bit block
            packed.At(Number<arr_index>{}) |= (bits << bit_offset);

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                packed.At(Number<arr_index + 1>{}) |= (bits >> (num_bits_elem - overhang));
            }
        });

        return packed;
    }
};
/// Packed container for thirty-two 6-bit bf6_t values stored in six 32-bit words.
///
/// Element k occupies the six bits starting at bit position 6 * k, counted from the
/// least-significant bit of word 0 and continuing through the following words. An element
/// may therefore straddle two neighbouring words.
struct bf6x32_pk_t
{
    // store 32 elements of bf6_t in an array of 6 uint32_t
    using element_type = uint32_t;
    using type         = StaticallyIndexedArray_v2<element_type, 6>;
    type data;

    // Unpacked form accepted by pack(): one int8_t lane per element.
    using test_vec_t = int8_t __attribute__((ext_vector_type(32)));

    /// Value-initializes the storage.
    __host__ __device__ bf6x32_pk_t() : data{type{}} {}

    /// Wraps already packed storage.
    __host__ __device__ bf6x32_pk_t(type init) : data{init} {}

    /// Extracts element I.
    /// @tparam I compile-time element index, 0 <= I < 32.
    /// @return the six bits of element I converted to bf6_t.
    template <index_t I>
    __host__ __device__ inline bf6_t unpack(Number<I>) const
    {
        static_assert(I >= 0 && I < 32, "Index out of range for 32 bf6_t elements.");

        constexpr int num_bits_elem     = 6;
        constexpr int num_bits_vec_elem = 32;
        constexpr int vector_size       = 6;
        constexpr int bit_pos           = I * num_bits_elem;
        constexpr int arr_idx           = bit_pos / num_bits_vec_elem;
        constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
        // number of bits of this element that live in the next word (<= 0 when none do)
        constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
        if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // take the low `overhang` bits of the next word and place them above the bits
            // already taken from the current word
            bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
                    << (num_bits_elem - overhang);
        }
        return static_cast<bf6_t>(bits & 0x3F);
    }

    /// Packs the low six bits of each of the 32 lanes of x into six 32-bit words.
    /// Does not read or modify any object state.
    /// @param x vector of 32 int8_t lanes; bits above the low six of each lane are ignored.
    /// @return the packed storage.
    __host__ __device__ static inline type pack(const test_vec_t& x)
    {
        type packed{};

        // for each of the 32 bf6_t values, place its 6 bits in the correct position
        ck::static_for<0, 32, 1>{}([&](auto i) {
            const uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;

            constexpr int num_bits_elem     = 6;
            constexpr int num_bits_vec_elem = 32;
            constexpr int vector_size       = 6;
            constexpr int bit_pos           = i * num_bits_elem;
            constexpr int arr_index         = bit_pos / num_bits_vec_elem;
            constexpr int bit_offset        = bit_pos % num_bits_vec_elem;
            constexpr int overhang          = bit_offset + num_bits_elem - num_bits_vec_elem;

            // insert bits into the current 32-bit block
            packed.At(Number<arr_index>{}) |= (bits << bit_offset);

            // if it crosses into the next block, shift the remainder
            if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
            {
                packed.At(Number<arr_index + 1>{}) |= (bits >> (num_bits_elem - overhang));
            }
        });

        return packed;
    }
};
// custom data type - pack int4 data
// custom data type - pack int4 data
struct
pk_i4_t
struct
pk_i4_t
{
{
...
@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x)
...
@@ -56,7 +321,7 @@ inline constexpr auto next_pow2(uint32_t x)
}
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
// native types: bool
, f4_t, f6_t, bf6_t
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
inline
constexpr
bool
is_native_type
()
{
{
...
@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
...
@@ -1387,12 +1652,37 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
{
using
type
=
f8_ocp_t
::
data_type
;
using
type
=
f8_ocp_t
::
data_type
;
};
};
template
<
>
template
<
>
struct
nnvb_data_t_selector
<
bf8_ocp_t
>
struct
nnvb_data_t_selector
<
bf8_ocp_t
>
{
{
using
type
=
bf8_ocp_t
::
data_type
;
using
type
=
bf8_ocp_t
::
data_type
;
};
};
// Selects the raw storage type for f6x16_pk_t: the pack's own nested `type` alias.
template
<
>
struct
nnvb_data_t_selector
<
f6x16_pk_t
>
{
using
type
=
f6x16_pk_t
::
type
;
};
// Selects the raw storage type for f6x32_pk_t: the pack's own nested `type` alias.
template
<
>
struct
nnvb_data_t_selector
<
f6x32_pk_t
>
{
using
type
=
f6x32_pk_t
::
type
;
};
// Selects the raw storage type for bf6x16_pk_t: the pack's own nested `type` alias.
template
<
>
struct
nnvb_data_t_selector
<
bf6x16_pk_t
>
{
using
type
=
bf6x16_pk_t
::
type
;
};
// Selects the raw storage type for bf6x32_pk_t: the pack's own nested `type` alias.
template
<
>
struct
nnvb_data_t_selector
<
bf6x32_pk_t
>
{
using
type
=
bf6x32_pk_t
::
type
;
};
template
<
>
template
<
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
{
{
...
@@ -1499,6 +1789,63 @@ struct non_native_vector_base<
...
@@ -1499,6 +1789,63 @@ struct non_native_vector_base<
}
}
};
};
// Specialization for the packed 6-bit containers (f6x16 and f6x32).
// It is selected purely by size: any T with sizeof(T) == 12 or sizeof(T) == 24 matches.
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
    // raw storage type backing one T, chosen from the declared base type
    using data_t = typename nnvb_data_t_selector<T>::type;
    // word type declared by T
    using element_t = typename T::element_type;

    static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");

    // number of element_t words per T -- f6x16: 12/4 = 3, f6x32: 24/4 = 6
    static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t);

    // flat vector of N * size_factor words
    using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
    using type   = non_native_vector_base<T, N>;

    // Four views over the same bytes; dN comes first, so it is the member that the
    // brace-initializers below initialize.
    union alignas(next_pow2(N * sizeof(T)))
    {
        data_v dN; // storage vector;
        StaticallyIndexedArray<data_t, N> dxN;
        StaticallyIndexedArray<T, N> dTxN;
        StaticallyIndexedArray<data_v, 1> dNx1;
    } data_;

    // Builds the storage vector from word 0 of `a`.
    // NOTE(review): only a.At(Number<0>{}) is read here; the remaining words of `a` never
    // reach data_. If the scalar-to-vector conversion splats, every word of dN ends up
    // equal to word 0 of `a` -- confirm this is intended.
    __host__ __device__ constexpr non_native_vector_base(data_t a)
        : data_{data_v(a.At(Number<0>{}))}
    {
    }

    // Reinterprets the bits of `f` as data_t and delegates to the data_t constructor.
    __host__ __device__ constexpr non_native_vector_base(T f)
        : non_native_vector_base(bit_cast<data_t>(f))
    {
    }

    // Delegates to the T constructor with a value-initialized T.
    __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}) {}

    // Adopts an already assembled storage vector.
    __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}

    // Returns the flat storage vector.
    __host__ __device__ constexpr operator data_v() const { return data_.dN; }

    // Returns element 0 viewed as data_t when N == 1.
    __host__ __device__ constexpr operator data_t() const
    {
        if constexpr(N != 1)
        {
            return data_.dxN; // XXX this should cause an error
        }
        else
        {
            return data_.dxN[Number<0>{}];
        }
    }

    // Returns element 0 viewed as T when N == 1.
    __host__ __device__ constexpr operator T() const
    {
        if constexpr(N != 1)
        {
            return data_.dTxN; // XXX this should cause an error
        }
        else
        {
            return data_.dTxN[Number<0>{}];
        }
    }
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
struct
scalar_type
<
non_native_vector_base
<
T
,
N
>>
;
...
@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
...
@@ -2242,6 +2589,14 @@ using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using
f4x32_t
=
typename
vector_type
<
f4x2_pk_t
,
16
>::
type
;
using
f4x32_t
=
typename
vector_type
<
f4x2_pk_t
,
16
>::
type
;
using
f4x64_t
=
typename
vector_type
<
f4x2_pk_t
,
32
>::
type
;
using
f4x64_t
=
typename
vector_type
<
f4x2_pk_t
,
32
>::
type
;
// f6
using
f6x16_t
=
typename
vector_type
<
f6x16_pk_t
,
1
>::
type
;
using
f6x32_t
=
typename
vector_type
<
f6x32_pk_t
,
1
>::
type
;
// bf6
using
bf6x16_t
=
typename
vector_type
<
bf6x16_pk_t
,
1
>::
type
;
using
bf6x32_t
=
typename
vector_type
<
bf6x32_pk_t
,
1
>::
type
;
// pack int4
// pack int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
...
...
include/ck/utility/mxfp_utils.hpp
View file @
24f6f4ab
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/utility/scaled_type_convert.hpp
View file @
24f6f4ab
...
@@ -6,11 +6,21 @@
...
@@ -6,11 +6,21 @@
#include "ck/utility/type_convert.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/mxf8_utils.hpp"
#include "ck/utility/mxf8_utils.hpp"
// Normalize CK_USE_NATIVE_MX_SUPPORT to 0 or 1 so it can be tested with `#if`.
// The macro may already be defined with a different replacement list when this header
// is reached, and redefining it in place with another value is an incompatible macro
// redefinition. Drop the incoming definition before setting the normalized one.
#ifdef CK_USE_NATIVE_MX_SUPPORT
#undef CK_USE_NATIVE_MX_SUPPORT
#define CK_USE_NATIVE_MX_SUPPORT 1
#else
#define CK_USE_NATIVE_MX_SUPPORT 0
#endif
namespace
ck
{
namespace
ck
{
// Declare a template function for scaled conversion
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
#if CK_USE_OCP_FP8
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#else
__host__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
#endif
// convert f8_ocp_t to fp32
// convert f8_ocp_t to fp32
template
<
>
template
<
>
...
@@ -200,27 +210,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
...
@@ -200,27 +210,13 @@ inline __host__ float32_t scaled_type_convert<float32_t, bf8x32_ocp_t>(e8m0_bexp
return
out
.
float_1x32
;
return
out
.
float_1x32
;
}
}
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert fp32 to fp8
// convert fp32 to fp8
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
inline
__host__
__device__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#else
inline
__host__
f8_ocp_t
scaled_type_convert
<
f8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -231,8 +227,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
...
@@ -231,8 +227,12 @@ inline __host__ __device__ f8_ocp_t scaled_type_convert<f8_ocp_t, float>(e8m0_be
// convert fp32 to bf8
// convert fp32 to bf8
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
float
x
)
#else
inline
__host__
bf8_ocp_t
scaled_type_convert
<
bf8_ocp_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -243,8 +243,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
...
@@ -243,8 +243,12 @@ inline __host__ __device__ bf8_ocp_t scaled_type_convert<bf8_ocp_t, float>(e8m0_
// convert fp32x2 to fp8x2
// convert fp32x2 to fp8x2
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
float2_t
x
)
#else
inline
__host__
f8x2_ocp_t
scaled_type_convert
<
f8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -254,8 +258,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
...
@@ -254,8 +258,13 @@ inline __host__ __device__ f8x2_ocp_t scaled_type_convert<f8x2_ocp_t, float2_t>(
}
}
// convert fp32x2 to bf8x2
// convert fp32x2 to bf8x2
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
float2_t
x
)
#else
inline
__host__
bf8x2_ocp_t
scaled_type_convert
<
bf8x2_ocp_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x2_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -267,8 +276,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
...
@@ -267,8 +276,13 @@ inline __host__ __device__ bf8x2_ocp_t scaled_type_convert<bf8x2_ocp_t, float2_t
// convert fp32x16 to fp8x16
// convert fp32x16 to fp8x16
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x16_ocp_t
inline
__host__
__device__
f8x16_ocp_t
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#else
inline
__host__
f8x16_ocp_t
scaled_type_convert
<
f8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -280,8 +294,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
...
@@ -280,8 +294,13 @@ scaled_type_convert<f8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x16 to bf8x16
// convert fp32x16 to bf8x16
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x16_ocp_t
inline
__host__
__device__
bf8x16_ocp_t
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#else
inline
__host__
bf8x16_ocp_t
scaled_type_convert
<
bf8x16_ocp_t
,
float16_t
>
(
e8m0_bexp_t
scale
,
float16_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x16_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -293,8 +312,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
...
@@ -293,8 +312,13 @@ scaled_type_convert<bf8x16_ocp_t, float16_t>(e8m0_bexp_t scale, float16_t x)
// convert fp32x32 to fp8x32
// convert fp32x32 to fp8x32
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
f8x32_ocp_t
inline
__host__
__device__
f8x32_ocp_t
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
f8x32_ocp_t
scaled_type_convert
<
f8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
f8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
f8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -306,8 +330,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
...
@@ -306,8 +330,13 @@ scaled_type_convert<f8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
// convert fp32x32 to bf8x32
// convert fp32x32 to bf8x32
// @note Host version gives compilation error. Requires extra compiler options.
// @note Host version gives compilation error. Requires extra compiler options.
template
<
>
template
<
>
#if CK_USE_OCP_FP8
inline
__host__
__device__
bf8x32_ocp_t
inline
__host__
__device__
bf8x32_ocp_t
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#else
inline
__host__
bf8x32_ocp_t
scaled_type_convert
<
bf8x32_ocp_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
#endif
{
{
#if CK_USE_SR_F8_CONVERSION
#if CK_USE_SR_F8_CONVERSION
return
mxf8_convert_sr
<
bf8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
return
mxf8_convert_sr
<
bf8x32_ocp_t
>
(
x
,
type_convert
<
float
>
(
scale
));
...
@@ -316,6 +345,26 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
...
@@ -316,6 +345,26 @@ scaled_type_convert<bf8x32_ocp_t, float32_t>(e8m0_bexp_t scale, float32_t x)
#endif
#endif
}
}
// activate for architectures with native MX support
#if CK_USE_NATIVE_MX_SUPPORT
// Scaled conversion of a single fp4 value to fp32.
// On gfx950 the packed conversion builtin is invoked and lane 0 of its two-float result
// is returned; on every other target utils::to_float<f4_t> is called instead.
template <>
inline __host__ __device__ float scaled_type_convert<float, f4_t>(e8m0_bexp_t scale, f4_t x)
{
#if defined(__gfx950__)
    // view the two-float builtin result as an array so lane 0 can be picked out
    union
    {
        float lanes[2];
        float2_t packed;
    } converted{};

    converted.packed = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert<float>(scale), 0);

    return converted.lanes[0];
#else
    return utils::to_float<f4_t>(scale, x);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
...
@@ -330,9 +379,10 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
...
@@ -330,9 +379,10 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
value
.
f4x2_array
[
0
]
=
x
;
value
.
f4x2_array
[
0
]
=
x
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
#else
float2_t
ret
{
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{})),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}))};
return
ret
;
return
ret
;
#endif
#endif
}
}
...
@@ -467,72 +517,104 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
...
@@ -467,72 +517,104 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}));
return
float_values
.
float32_array
;
return
float_values
.
float32_array
;
#endif
#endif
...
@@ -584,8 +666,59 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
...
@@ -584,8 +666,59 @@ inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t, float32_t>(e8m0_
template
<
>
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f6_t
>
(
e8m0_bexp_t
scale
,
f6_t
x
)
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f6_t
>
(
e8m0_bexp_t
scale
,
f6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
in
.
f6_vector
,
type_convert
<
float
>
(
scale
));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f6_t
>
(
scale
,
x
);
return
utils
::
to_float
<
f6_t
>
(
scale
,
x
);
#endif
}
/**
* @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The f6x32_t vector to be converted.
* @return The converted float vector representation of the input.
*/
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f6x32_t
>
(
e8m0_bexp_t
scale
,
f6x32_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
x
,
type_convert
<
float
>
(
scale
));
#else
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
f6_t
>
(
scale
,
in
.
f6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
/**
/**
...
@@ -599,8 +732,59 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
...
@@ -599,8 +732,59 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
template
<
>
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
bf6_t
>
(
e8m0_bexp_t
scale
,
bf6_t
x
)
inline
__host__
__device__
float
scaled_type_convert
<
float
,
bf6_t
>
(
e8m0_bexp_t
scale
,
bf6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
in
.
bf6_vector
,
type_convert
<
float
>
(
scale
));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
bf6_t
>
(
scale
,
x
);
return
utils
::
to_float
<
bf6_t
>
(
scale
,
x
);
#endif
}
/**
* @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats,
* applying the specified scaling factor.
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The bf6x32_t vector to be converted.
* @return The converted vector of 32 float representation of the input.
*/
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
bf6x32_t
>
(
e8m0_bexp_t
scale
,
bf6x32_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
x
,
type_convert
<
float
>
(
scale
));
#else
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
float_array
[
i
]
=
utils
::
to_float
<
bf6_t
>
(
scale
,
in
.
bf6_array
[
i
]);
});
return
out
.
float_vector
;
#endif
}
}
/**
/**
...
@@ -624,6 +808,28 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
...
@@ -624,6 +808,28 @@ inline __host__ __device__ f6_t scaled_type_convert<f6_t, float>(e8m0_bexp_t sca
#endif
#endif
}
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (f6_convert_sr) or round-to-nearest-even (f6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted vector of 6-bit floating-point values (f6x32_t).
*/
template
<
>
inline
__host__
__device__
f6x32_t
scaled_type_convert
<
f6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f6_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
/**
/**
* @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
* @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified
* scale.
* scale.
...
@@ -645,4 +851,27 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
...
@@ -645,4 +851,27 @@ inline __host__ __device__ bf6_t scaled_type_convert<bf6_t, float>(e8m0_bexp_t s
#endif
#endif
}
}
/**
* @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t),
* applying the specified scale.
*
* Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding
* (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne).
*
* @param scale The exponent scale factor (e8m0_bexp_t).
* @param x The float vector to convert.
* @return The converted 6-bit floating-point vector (bf6x32_t).
*/
template
<
>
inline
__host__
__device__
bf6x32_t
scaled_type_convert
<
bf6x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
bf6_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
#endif // #if CK_USE_NATIVE_MX_SUPPORT
}
// namespace ck
}
// namespace ck
include/ck/utility/type_convert.hpp
View file @
24f6f4ab
...
@@ -1146,10 +1146,11 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
...
@@ -1146,10 +1146,11 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
float
scale
=
1.0
f
;
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
scale
,
0
);
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
scale
,
0
);
#else
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
float2_t
ret
{
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{})),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}))};
return
ret
;
return
ret
;
#endif
#endif
}
}
...
@@ -1285,103 +1286,103 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
...
@@ -1285,103 +1286,103 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
// TODO: pack in a loop
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
0
>
{}
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<>
(
Number
<
1
>
{}
));
return
float_values
.
float32_array
;
return
float_values
.
float32_array
;
#endif
#endif
...
@@ -1399,8 +1400,59 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
...
@@ -1399,8 +1400,59 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
*/
*/
inline
__host__
__device__
f6_t
f6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
inline
__host__
__device__
f6_t
f6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t
in1
{
x
};
float16_t
in2
{};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
out
.
f6_vector
=
__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32
(
in1
,
in2
,
scale
);
return
out
.
f6_array
[
0
];
#else
return
utils
::
sat_convert_to_type
<
f6_t
>
(
x
/
scale
);
return
utils
::
sat_convert_to_type
<
f6_t
>
(
x
/
scale
);
#endif
}
/**
* @brief Converts a 32-element single-precision float array into a packed 6-bit representation.
*
* This function divides each input float by the provided scale value, then performs conversion with
* rounding to nearest / even to pack each element into 6 bits of precision.
*
* @param x A vector of 32 floats stored in float32_t.
* @param scale A scaling factor for each float before conversion.
* @return An f6x32_t object storing the compressed 6-bit representation.
*/
inline
__host__
__device__
f6x32_t
f6_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
float16_t
*
in1
=
reinterpret_cast
<
float16_t
*>
(
&
x
);
float16_t
*
in2
=
reinterpret_cast
<
float16_t
*>
(
&
x
+
16
);
return
__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32
(
*
in1
,
*
in2
,
scale
);
#else
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
out
.
f6_array
[
i
]
=
utils
::
sat_convert_to_type
<
f6_t
>
(
in
.
float_array
[
i
]
/
scale
);
});
return
out
.
f6_vector
;
#endif
}
}
/**
/**
...
@@ -1417,15 +1469,75 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
...
@@ -1417,15 +1469,75 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
out
{};
out
.
f6_vector
=
__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32
(
in
.
float_vector
,
rng
,
scale
);
return
out
.
f6_array
[
0
];
#else
return
utils
::
sat_convert_to_type_sr
<
f6_t
>
(
x
/
scale
,
rng
);
return
utils
::
sat_convert_to_type_sr
<
f6_t
>
(
x
/
scale
,
rng
);
#endif
}
/**
 * @brief Converts a vector of 32 single-precision floats into a packed 6-bit (f6) vector
 *        using stochastic rounding.
 *
 * A single pseudo-random word is generated from the address of @p x and its first element;
 * that same word is used for all 32 elements. In the generic path every element is divided
 * by @p scale and converted with utils::sat_convert_to_type_sr<f6_t>; on gfx950 the vector,
 * the random word and @p scale are passed to __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32.
 *
 * @param x     A vector of 32 floats stored in float32_t.
 * @param scale A scaling factor for each float before conversion.
 * @return An f6x32_t object storing the compressed 6-bit representation.
 */
inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;

    // Element-wise view of the input. It serves both the RNG seeding below and the generic
    // conversion loop, so the input is copied into a union only once.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};

    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), in.float_array[0]);

#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
#else
    // Element-wise view of the result vector, zero-initialised.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } out{};

    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
    });

    return out.f6_vector;
#endif
}
}
/**
/**
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* @brief Specializes the type conversion template for converting a float into the 6-bit float type
* (f6_t).
* (f6_t).
*
*
* Depending on the CK_USE_SR_F
4
_CONVERSION flag,
* Depending on the CK_USE_SR_F
6
_CONVERSION flag,
* the conversion uses stochastic rounding
* the conversion uses stochastic rounding
* or round-to-nearest-even.
* or round-to-nearest-even.
*
*
...
@@ -1435,7 +1547,28 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
...
@@ -1435,7 +1547,28 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
template
<
>
template
<
>
inline
__host__
__device__
f6_t
type_convert
<
f6_t
,
float
>
(
float
x
)
inline
__host__
__device__
f6_t
type_convert
<
f6_t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
);
#else
return
f6_convert_rne
(
x
);
#endif
}
/**
* @brief Specializes the type conversion template for converting a vector of 32 floats into the
* vector of 32 6-bit float types (f6x32_t).
*
* Depending on the CK_USE_SR_F6_CONVERSION flag,
* the conversion uses stochastic rounding
* or round-to-nearest-even.
*
* @param x Input vector of 32 floats to be converted.
* @return The converted f6x32_t vector.
*/
template
<
>
inline
__host__
__device__
f6x32_t
type_convert
<
f6x32_t
,
float32_t
>
(
float32_t
x
)
{
#if CK_USE_SR_F6_CONVERSION
return
f6_convert_sr
(
x
);
return
f6_convert_sr
(
x
);
#else
#else
return
f6_convert_rne
(
x
);
return
f6_convert_rne
(
x
);
...
@@ -1454,8 +1587,62 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
...
@@ -1454,8 +1587,62 @@ inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f6_t
>
(
f6_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f6_t
>
(
f6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
f6x32_t
f6_vector
;
f6_t
f6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6
(
in
.
f6_vector
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
return
utils
::
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
#endif
}
/**
 * @brief type_convert specialization turning a vector of 32 6-bit floats (f6x32_t) into a
 *        vector of 32 single-precision floats.
 *
 * The conversion always uses NumericLimits<e8m0_bexp_t>::Binary_1() as the scale
 * (the default scale factor of 1).
 *
 * @param x The vector of 32 6-bit float (f6x32_t) values to be converted.
 * @return The corresponding float32_t vector.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
    // The builtin expects the scale as a float.
    const float unit_scale = type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1());
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, unit_scale);
#else
    // Element-wise view of the 6-bit input vector.
    union
    {
        f6x32_t f6_vector;
        f6_t f6_array[32];
    } packed{x};

    // Element-wise view of the float result, zero-initialised.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } unpacked{};

    ck::static_for<0, 32, 1>{}([&](auto lane) {
        unpacked.float_array[lane] =
            utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), packed.f6_array[lane]);
    });

    return unpacked.float_vector;
#endif
}
}
/**
/**
...
@@ -1470,8 +1657,60 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
...
@@ -1470,8 +1657,60 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
*/
*/
inline
__host__
__device__
bf6_t
bf6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
inline
__host__
__device__
bf6_t
bf6_convert_rne
(
float
x
,
float
scale
=
1.0
f
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
float16_t
in1
{
x
};
float16_t
in2
{};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
out
.
bf6_vector
=
__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32
(
in1
,
in2
,
scale
);
return
out
.
bf6_array
[
0
];
#else
return
utils
::
sat_convert_to_type
<
bf6_t
>
(
x
/
scale
);
return
utils
::
sat_convert_to_type
<
bf6_t
>
(
x
/
scale
);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using
 *        round-to-nearest-even.
 *
 * Generic path: every element is divided by @p scale and converted individually with
 * utils::sat_convert_to_type<bf6_t>.
 * gfx950 path: the vector is split into its two 16-float halves and handed, together with
 * @p scale, to __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32.
 *
 * @param x     The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
{
#if defined(__gfx950__)
    // Cast first, index afterwards. The former `&x + 16` performed pointer arithmetic in units
    // of float32_t, i.e. it skipped 16 whole 32-float vectors and read far outside of `x`.
    // NOTE(review): float16_t is assumed to be the 16-element float vector type - confirm.
    const float16_t* halves = reinterpret_cast<const float16_t*>(&x);
    return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(halves[0], halves[1], scale);
#else
    // Element-wise view of the input vector.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};

    // Element-wise view of the result vector, zero-initialised.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};

    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
    });

    return out.bf6_vector;
#endif
}
}
/**
/**
...
@@ -1489,14 +1728,76 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
...
@@ -1489,14 +1728,76 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
in
{
x
};
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
out
{};
out
.
bf6_vector
=
__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32
(
in
.
float_vector
,
rng
,
scale
);
return
out
.
bf6_array
[
0
];
#else
return
utils
::
sat_convert_to_type_sr
<
bf6_t
>
(
x
/
scale
,
rng
);
return
utils
::
sat_convert_to_type_sr
<
bf6_t
>
(
x
/
scale
,
rng
);
#endif
}
/**
 * @brief Converts a vector of 32 floats to the vector of 32 6-bit BF6 types using stochastic
 *        rounding.
 *
 * A single pseudo-random word is generated from the address of @p x and its first element;
 * that same word is used for all 32 elements. In the generic path every element is divided
 * by @p scale and converted with utils::sat_convert_to_type_sr<bf6_t>; on gfx950 the vector,
 * the random word and @p scale are passed to __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32.
 *
 * @param x     The float vector to be converted.
 * @param scale The scaling factor applied to the input before conversion.
 * @return The converted bf6x32_t vector.
 */
inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
{
    constexpr int seed = 1254739;

    // Element-wise view of the input. It serves both the RNG seeding below and the generic
    // conversion loop, so the input is copied into a union only once.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } in{x};

    uint32_t rng =
        prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), in.float_array[0]);

#if defined(__gfx950__)
    return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
#else
    // Element-wise view of the result vector, zero-initialised.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } out{};

    ck::static_for<0, 32, 1>{}([&](auto i) {
        out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
    });

    return out.bf6_vector;
#endif
}
}
/**
/**
* @brief Specializes float-to-bf6_t conversion.
* @brief Specializes float-to-bf6_t conversion.
*
*
* Uses stochastic rounding if CK_USE_SR_F
4
_CONVERSION is defined,
* Uses stochastic rounding if CK_USE_SR_F
6
_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
* otherwise uses round-to-nearest-even.
*
*
* @param x Input float value to convert.
* @param x Input float value to convert.
...
@@ -1505,7 +1806,26 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
...
@@ -1505,7 +1806,26 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
template
<
>
template
<
>
inline
__host__
__device__
bf6_t
type_convert
<
bf6_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf6_t
type_convert
<
bf6_t
,
float
>
(
float
x
)
{
{
#if CK_USE_SR_F4_CONVERSION
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
);
#else
return
bf6_convert_rne
(
x
);
#endif
}
/**
* @brief Specializes vector of 32 float-to-bf6_t conversion.
*
* Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined,
* otherwise uses round-to-nearest-even.
*
* @param x Input float vector to convert.
* @return Converted bf6x32_t vector.
*/
template
<
>
inline
__host__
__device__
bf6x32_t
type_convert
<
bf6x32_t
,
float32_t
>
(
float32_t
x
)
{
#if CK_USE_SR_F6_CONVERSION
return
bf6_convert_sr
(
x
);
return
bf6_convert_sr
(
x
);
#else
#else
return
bf6_convert_rne
(
x
);
return
bf6_convert_rne
(
x
);
...
@@ -1524,8 +1844,63 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
...
@@ -1524,8 +1844,63 @@ inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
template
<
>
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
bf6_t
>
(
bf6_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
bf6_t
>
(
bf6_t
x
)
{
{
// currently there is no native conversion instruction
#if defined(__gfx950__)
union
{
bf6x32_t
bf6_vector
;
bf6_t
bf6_array
[
32
];
}
in
{
x
};
union
{
float32_t
float_vector
;
float
float_array
[
32
];
}
out
{};
out
.
float_vector
=
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6
(
in
.
bf6_vector
,
type_convert
<
float
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
()));
return
out
.
float_array
[
0
];
#else
return
utils
::
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
return
utils
::
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
x
);
#endif
}
/**
 * @brief type_convert specialization turning a vector of 32 bf6_t values (bf6x32_t) into a
 *        vector of 32 single-precision floats.
 *
 * The conversion always uses NumericLimits<e8m0_bexp_t>::Binary_1() as the scale
 * (the default scale factor of 1).
 *
 * @param x The bf6x32_t value to convert.
 * @return The float32_t representation of the given vector.
 */
template <>
inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t x)
{
#if defined(__gfx950__)
    // The builtin expects the scale as a float.
    const float unit_scale = type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1());
    return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, unit_scale);
#else
    // Element-wise view of the 6-bit input vector.
    union
    {
        bf6x32_t bf6_vector;
        bf6_t bf6_array[32];
    } packed{x};

    // Element-wise view of the float result, zero-initialised.
    union
    {
        float32_t float_vector;
        float float_array[32];
    } unpacked{};

    ck::static_for<0, 32, 1>{}([&](auto lane) {
        unpacked.float_array[lane] = utils::to_float<bf6_t>(
            NumericLimits<e8m0_bexp_t>::Binary_1(), packed.bf6_array[lane]);
    });

    return unpacked.float_vector;
#endif
}
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
...
...
test/data_type/test_bf6.cpp
View file @
24f6f4ab
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
using
ck
::
bf6_convert_rne
;
using
ck
::
bf6_convert_rne
;
using
ck
::
bf6_convert_sr
;
using
ck
::
bf6_convert_sr
;
using
ck
::
bf6_t
;
using
ck
::
bf6_t
;
using
ck
::
bf6x16_pk_t
;
using
ck
::
bf6x32_pk_t
;
using
ck
::
e8m0_bexp_t
;
using
ck
::
e8m0_bexp_t
;
using
ck
::
Number
;
using
ck
::
Number
;
using
ck
::
scaled_type_convert
;
using
ck
::
scaled_type_convert
;
...
@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic)
...
@@ -216,3 +218,171 @@ TEST(BF6, ScaledConvertFP32Stochastic)
scaled_type_convert
<
float
>
(
e8m0_bexp_t
(
min_scale
),
bf6_convert_sr
(
neg_float
)),
scaled_type_convert
<
float
>
(
e8m0_bexp_t
(
min_scale
),
bf6_convert_sr
(
neg_float
)),
abs_tol
);
abs_tol
);
}
}
// Pins the storage footprint of the BF6 scalar, packed and vector types.
TEST(BF6, TestSize)
{
    // scalar: one byte per value
    ASSERT_EQ(1, sizeof(bf6_t));

    // packed: 16 x 6 bit = 12 bytes, 32 x 6 bit = 24 bytes
    ASSERT_EQ(12, sizeof(bf6x16_pk_t));
    ASSERT_EQ(24, sizeof(bf6x32_pk_t));

    // vector_type wrappers are expected to be larger than the raw packed payload
    ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
    ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
    ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
}
// Pins the alignment requirements of the BF6 scalar, packed and vector types.
TEST(BF6, TestAlignment)
{
    // scalar
    ASSERT_EQ(1, alignof(bf6_t));

    // packed types
    ASSERT_EQ(4, alignof(bf6x16_pk_t));
    ASSERT_EQ(4, alignof(bf6x32_pk_t));

    // vector_type wrappers: expected alignment equals the wrapper size checked in TestSize
    ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
    ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
    ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));
}
// vector_type holding a single bf6x16_pk_t, i.e. 16 bf6_t values in total
TEST(BF6, TestAsType16x1)
{
    // dimensions under test
    constexpr int vector_size = 1;
    constexpr int packed_size = 16;

    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    // input bit patterns, one per packed element
    test_vec_t test_vec = {bf6_t(0b000000),
                           bf6_t(0b100000),
                           bf6_t(0b000001),
                           bf6_t(0b100001),
                           bf6_t(0b000010),
                           bf6_t(0b100010),
                           bf6_t(0b000011),
                           bf6_t(0b100011),
                           bf6_t(0b000100),
                           bf6_t(0b100100),
                           bf6_t(0b000101),
                           bf6_t(0b100101),
                           bf6_t(0b000110),
                           bf6_t(0b100110),
                           bf6_t(0b001011),
                           bf6_t(0b101011)};

    // source vector, default constructed
    vector_type<bf6x16_pk_t, vector_size> right_vec;

    // a default constructed vector must unpack to zero in every element
    ck::static_for<0, packed_size, 1>{}([&](auto elem) {
        ASSERT_EQ(
            right_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<elem>{}),
            0);
    });

    // pack the input patterns into every slot of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto slot) {
        right_vec.template AsType<bf6x16_pk_t>()(Number<slot>{}) = bf6x16_pk_t{}.pack(test_vec);
    });

    // copy construct the destination from the source
    vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};

    // every element of the copy must unpack to the corresponding input pattern
    ck::static_for<0, packed_size, 1>{}([&](auto elem) {
        ASSERT_EQ(
            left_vec.template AsType<bf6x16_pk_t>()(Number<0>{}).template unpack<>(Number<elem>{}),
            static_cast<bf6_t>(test_vec[static_cast<int>(elem)]));
    });
}
// vector_type holding two bf6x16_pk_t, i.e. 32 bf6_t values in total
TEST(BF6, TestAsType16x2)
{
    // dimensions under test
    constexpr int vector_size = 2;
    constexpr int packed_size = 16;

    typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));

    // input bit patterns: one 16-element group per packed slot
    test_vec_t test_vec[2];
    test_vec[0] = {bf6_t(0b000000),
                   bf6_t(0b100000),
                   bf6_t(0b000001),
                   bf6_t(0b100001),
                   bf6_t(0b000010),
                   bf6_t(0b100010),
                   bf6_t(0b000011),
                   bf6_t(0b100011),
                   bf6_t(0b000100),
                   bf6_t(0b100100),
                   bf6_t(0b000101),
                   bf6_t(0b100101),
                   bf6_t(0b000110),
                   bf6_t(0b100110),
                   bf6_t(0b001011),
                   bf6_t(0b101011)};
    test_vec[1] = {bf6_t(0b010000),
                   bf6_t(0b110000),
                   bf6_t(0b010001),
                   bf6_t(0b110001),
                   bf6_t(0b010010),
                   bf6_t(0b110010),
                   bf6_t(0b010011),
                   bf6_t(0b110011),
                   bf6_t(0b010100),
                   bf6_t(0b110100),
                   bf6_t(0b010101),
                   bf6_t(0b110101),
                   bf6_t(0b010110),
                   bf6_t(0b110110),
                   bf6_t(0b011011),
                   bf6_t(0b111011)};

    // source vector, default constructed
    vector_type<bf6x16_pk_t, vector_size> right_vec;

    // a default constructed vector must unpack to zero in every element of every slot
    ck::static_for<0, vector_size, 1>{}([&](auto slot) {
        ck::static_for<0, packed_size, 1>{}([&](auto elem) {
            ASSERT_EQ(right_vec.template AsType<bf6x16_pk_t>()(Number<slot>{})
                          .template unpack<>(Number<elem>{}),
                      0);
        });
    });

    // pack each input group into its slot of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto slot) {
        right_vec.template AsType<bf6x16_pk_t>()(Number<slot>{}) =
            bf6x16_pk_t{}.pack(test_vec[slot]);
    });

    // copy construct the destination from the source
    vector_type<bf6x16_pk_t, vector_size> left_vec{right_vec};

    // every element of the copy must unpack to the corresponding input pattern
    ck::static_for<0, vector_size, 1>{}([&](auto slot) {
        ck::static_for<0, packed_size, 1>{}([&](auto elem) {
            ASSERT_EQ(left_vec.template AsType<bf6x16_pk_t>()(Number<slot>{})
                          .template unpack<>(Number<elem>{}),
                      static_cast<bf6_t>(test_vec[slot][static_cast<int>(elem)]));
        });
    });
}
// vector_type holding a single bf6x32_pk_t, i.e. 32 bf6_t values in total
TEST(BF6, TestAsType32x1)
{
    // dimensions under test
    constexpr int vector_size = 1;
    constexpr int packed_size = 32;

    typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));

    // input bit patterns, one per packed element
    test_vec_t test_vec = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001),
                           bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011),
                           bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101),
                           bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011),
                           bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001),
                           bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011),
                           bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101),
                           bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)};

    // source vector, default constructed
    vector_type<bf6x32_pk_t, vector_size> right_vec;

    // a default constructed vector must unpack to zero in every element
    ck::static_for<0, packed_size, 1>{}([&](auto elem) {
        ASSERT_EQ(
            right_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<elem>{}),
            0);
    });

    // pack the input patterns into every slot of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto slot) {
        right_vec.template AsType<bf6x32_pk_t>()(Number<slot>{}) = bf6x32_pk_t{}.pack(test_vec);
    });

    // copy construct the destination from the source
    vector_type<bf6x32_pk_t, vector_size> left_vec{right_vec};

    // every element of the copy must unpack to the corresponding input pattern
    ck::static_for<0, packed_size, 1>{}([&](auto elem) {
        ASSERT_EQ(
            left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<elem>{}),
            static_cast<bf6_t>(test_vec[static_cast<int>(elem)]));
    });
}
test/data_type/test_fp4.cpp
View file @
24f6f4ab
...
@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1)
...
@@ -235,8 +235,10 @@ TEST(FP4, TestAsType1)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1)
...
@@ -247,9 +249,9 @@ TEST(FP4, TestAsType1)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
...
@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2)
...
@@ -267,8 +269,10 @@ TEST(FP4, TestAsType2)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2)
...
@@ -279,9 +283,9 @@ TEST(FP4, TestAsType2)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
...
@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4)
...
@@ -303,8 +307,10 @@ TEST(FP4, TestAsType4)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4)
...
@@ -315,9 +321,9 @@ TEST(FP4, TestAsType4)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
...
@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8)
...
@@ -347,8 +353,10 @@ TEST(FP4, TestAsType8)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8)
...
@@ -359,9 +367,9 @@ TEST(FP4, TestAsType8)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
...
@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16)
...
@@ -387,8 +395,10 @@ TEST(FP4, TestAsType16)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16)
...
@@ -399,9 +409,9 @@ TEST(FP4, TestAsType16)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
...
@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32)
...
@@ -438,8 +448,10 @@ TEST(FP4, TestAsType32)
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
vector_type
<
f4x2_pk_t
,
size
>
right_vec
;
// check default CTOR
// check default CTOR
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
0
);
ASSERT_EQ
(
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
0
);
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}),
0
);
ASSERT_EQ
(
right_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}),
0
);
});
});
// assign test values to the vector
// assign test values to the vector
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32)
...
@@ -450,9 +462,9 @@ TEST(FP4, TestAsType32)
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
vector_type
<
f4x2_pk_t
,
size
>
left_vec
{
right_vec
};
// check if values were copied correctly
// check if values were copied correctly
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
static_for
<
0
,
size
,
1
>
{}([
&
](
auto
i
)
{
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
0
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
0
>
{}
),
test_vec
.
at
(
i
));
test_vec
.
at
(
i
));
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
1
>(),
ASSERT_EQ
(
left_vec
.
template
AsType
<
f4x2_pk_t
>()(
Number
<
i
>
{}).
template
unpack
<
>(
Number
<
1
>
{}
),
test_vec
.
at
(
i
+
1
));
test_vec
.
at
(
i
+
1
));
});
});
}
}
test/data_type/test_fp6.cpp
View file @
24f6f4ab
...
@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t;
...
@@ -10,6 +10,8 @@ using ck::e8m0_bexp_t;
using
ck
::
f6_convert_rne
;
using
ck
::
f6_convert_rne
;
using
ck
::
f6_convert_sr
;
using
ck
::
f6_convert_sr
;
using
ck
::
f6_t
;
using
ck
::
f6_t
;
using
ck
::
f6x16_pk_t
;
using
ck
::
f6x32_pk_t
;
using
ck
::
Number
;
using
ck
::
Number
;
using
ck
::
scaled_type_convert
;
using
ck
::
scaled_type_convert
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
...
@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic)
...
@@ -215,3 +217,169 @@ TEST(FP6, ScaledConvertFP32Stochastic)
scaled_type_convert
<
float
>
(
e8m0_bexp_t
(
min_scale
),
f6_convert_sr
(
neg_float
)),
scaled_type_convert
<
float
>
(
e8m0_bexp_t
(
min_scale
),
f6_convert_sr
(
neg_float
)),
abs_tol
);
abs_tol
);
}
}
// Pins the storage footprint (in bytes) of the FP6 scalar, the packed FP6
// containers, and the vector_type wrappers instantiated on those containers.
TEST(FP6, TestSize)
{
    // scalar and packed containers
    ASSERT_EQ(1, sizeof(f6_t));
    ASSERT_EQ(12, sizeof(f6x16_pk_t));
    ASSERT_EQ(24, sizeof(f6x32_pk_t));

    // vector_type wrappers: the expected sizes are larger than the raw
    // 12- and 24-byte payloads asserted above
    ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
    ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
    ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
}
// Pins the alignment requirement (in bytes) of the FP6 scalar, the packed FP6
// containers, and the vector_type wrappers instantiated on those containers.
TEST(FP6, TestAlignment)
{
    // scalar and packed containers
    ASSERT_EQ(1, alignof(f6_t));
    ASSERT_EQ(4, alignof(f6x16_pk_t));
    ASSERT_EQ(4, alignof(f6x32_pk_t));

    // vector_type wrappers
    ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
    ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
    ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
}
// Round-trip test for a vector_type holding one f6x16_pk_t (16 packed f6_t):
// default construction, element assignment through AsType, and copy construction.
TEST(FP6, TestAsType16x1)
{
    // number of packed elements in the vector / number of f6_t per packed element
    constexpr int vector_size = 1;
    constexpr int packed_size = 16;

    // unpacked reference data: one int8_t lane per 6-bit value
    using test_vec_t = int8_t __attribute__((ext_vector_type(16)));

    // each low-bit pattern is listed twice, with the top bit clear and then set
    test_vec_t test_vec = {f6_t(0b000000),
                           f6_t(0b100000),
                           f6_t(0b000001),
                           f6_t(0b100001),
                           f6_t(0b000010),
                           f6_t(0b100010),
                           f6_t(0b000011),
                           f6_t(0b100011),
                           f6_t(0b000100),
                           f6_t(0b100100),
                           f6_t(0b000101),
                           f6_t(0b100101),
                           f6_t(0b000110),
                           f6_t(0b100110),
                           f6_t(0b001011),
                           f6_t(0b101011)};

    // source vector, default constructed
    vector_type<f6x16_pk_t, vector_size> right_vec;

    // a default-constructed vector must unpack to all zeros
    ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
        ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(
                      Number<idx_element>{}),
                  0);
    });

    // pack the reference data into every element of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
        right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}) =
            f6x16_pk_t{}.pack(test_vec);
    });

    // copy-construct a second vector from the source
    vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};

    // every unpacked value of the copy must match the reference data
    ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
        ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<0>{}).template unpack<>(
                      Number<idx_element>{}),
                  static_cast<f6_t>(test_vec[static_cast<int>(idx_element)]));
    });
}
// Round-trip test for a vector_type holding two f6x16_pk_t (32 f6_t in total):
// default construction, element assignment through AsType, and copy construction.
TEST(FP6, TestAsType16x2)
{
    // number of packed elements in the vector / number of f6_t per packed element
    constexpr int vector_size = 2;
    constexpr int packed_size = 16;

    // unpacked reference data: one int8_t lane per 6-bit value
    using test_vec_t = int8_t __attribute__((ext_vector_type(16)));

    // one reference lane vector per packed element
    test_vec_t test_vec[2];

    test_vec[0] = {f6_t(0b000000),
                   f6_t(0b100000),
                   f6_t(0b000001),
                   f6_t(0b100001),
                   f6_t(0b000010),
                   f6_t(0b100010),
                   f6_t(0b000011),
                   f6_t(0b100011),
                   f6_t(0b000100),
                   f6_t(0b100100),
                   f6_t(0b000101),
                   f6_t(0b100101),
                   f6_t(0b000110),
                   f6_t(0b100110),
                   f6_t(0b001011),
                   f6_t(0b101011)};

    test_vec[1] = {f6_t(0b010000),
                   f6_t(0b110000),
                   f6_t(0b010001),
                   f6_t(0b110001),
                   f6_t(0b010010),
                   f6_t(0b110010),
                   f6_t(0b010011),
                   f6_t(0b110011),
                   f6_t(0b010100),
                   f6_t(0b110100),
                   f6_t(0b010101),
                   f6_t(0b110101),
                   f6_t(0b010110),
                   f6_t(0b110110),
                   f6_t(0b011011),
                   f6_t(0b111011)};

    // source vector, default constructed
    vector_type<f6x16_pk_t, vector_size> right_vec;

    // a default-constructed vector must unpack to all zeros
    ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
        ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
            ASSERT_EQ(right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
                          .template unpack<>(Number<idx_element>{}),
                      0);
        });
    });

    // pack each reference lane vector into the matching element of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
        right_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{}) =
            f6x16_pk_t{}.pack(test_vec[idx_vector]);
    });

    // copy-construct a second vector from the source
    vector_type<f6x16_pk_t, vector_size> left_vec{right_vec};

    // every unpacked value of the copy must match the reference data
    ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
        ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
            ASSERT_EQ(left_vec.template AsType<f6x16_pk_t>()(Number<idx_vector>{})
                          .template unpack<>(Number<idx_element>{}),
                      static_cast<f6_t>(test_vec[idx_vector][static_cast<int>(idx_element)]));
        });
    });
}
// Round-trip test for a vector_type holding one f6x32_pk_t (32 packed f6_t):
// default construction, element assignment through AsType, and copy construction.
TEST(FP6, TestAsType32x1)
{
    // number of packed elements in the vector / number of f6_t per packed element
    constexpr int vector_size = 1;
    constexpr int packed_size = 32;

    // unpacked reference data: one int8_t lane per 6-bit value
    using test_vec_t = int8_t __attribute__((ext_vector_type(32)));

    // each low-bit pattern is listed twice, with the top bit clear and then set
    test_vec_t test_vec = {f6_t(0b000000),
                           f6_t(0b100000),
                           f6_t(0b000001),
                           f6_t(0b100001),
                           f6_t(0b000010),
                           f6_t(0b100010),
                           f6_t(0b000011),
                           f6_t(0b100011),
                           f6_t(0b000100),
                           f6_t(0b100100),
                           f6_t(0b000101),
                           f6_t(0b100101),
                           f6_t(0b000110),
                           f6_t(0b100110),
                           f6_t(0b001011),
                           f6_t(0b101011),
                           f6_t(0b010000),
                           f6_t(0b110000),
                           f6_t(0b010001),
                           f6_t(0b110001),
                           f6_t(0b010010),
                           f6_t(0b110010),
                           f6_t(0b010011),
                           f6_t(0b110011),
                           f6_t(0b010100),
                           f6_t(0b110100),
                           f6_t(0b010101),
                           f6_t(0b110101),
                           f6_t(0b010110),
                           f6_t(0b110110),
                           f6_t(0b011011),
                           f6_t(0b111011)};

    // source vector, default constructed
    vector_type<f6x32_pk_t, vector_size> right_vec;

    // a default-constructed vector must unpack to all zeros
    ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
        ASSERT_EQ(right_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(
                      Number<idx_element>{}),
                  0);
    });

    // pack the reference data into every element of the source vector
    ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) {
        right_vec.template AsType<f6x32_pk_t>()(Number<idx_vector>{}) =
            f6x32_pk_t{}.pack(test_vec);
    });

    // copy-construct a second vector from the source
    vector_type<f6x32_pk_t, vector_size> left_vec{right_vec};

    // every unpacked value of the copy must match the reference data
    ck::static_for<0, packed_size, 1>{}([&](auto idx_element) {
        ASSERT_EQ(left_vec.template AsType<f6x32_pk_t>()(Number<0>{}).template unpack<>(
                      Number<idx_element>{}),
                  static_cast<f6_t>(test_vec[static_cast<int>(idx_element)]));
    });
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment