Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ace76831
Unverified
Commit
ace76831
authored
Dec 12, 2024
by
Andriy Roshchenko
Committed by
GitHub
Dec 12, 2024
Browse files
Move scaled_type_convert functions to a separate header (#251)
parent
54eaae8f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
287 additions
and
279 deletions
+287
-279
include/ck/utility/scaled_type_convert.hpp
include/ck/utility/scaled_type_convert.hpp
+286
-0
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+0
-279
test/data_type/test_fp4.cpp
test/data_type/test_fp4.cpp
+1
-0
No files found.
include/ck/utility/scaled_type_convert.hpp
0 → 100644
View file @
ace76831
#pragma once
#include "ck/utility/type_convert.hpp"
namespace
ck
{
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{};
value
.
f4x2_array
[
0
]
=
x
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
bitwise_value
{};
float2_t
op
;
float32_t
ret
;
// TODO: pack in a loop
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
0
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
2
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
3
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
4
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
5
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
6
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
7
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
8
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
9
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
10
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
11
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
12
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
13
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
14
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
15
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
return
float_values
.
float32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
}
// namespace ck
include/ck/utility/type_convert.hpp
View file @
ace76831
...
...
@@ -1353,285 +1353,6 @@ inline __host__ __device__ e8m0_bexp_t type_convert<e8m0_bexp_t, float>(float sc
return
utils
::
cast_from_float
(
scale
);
}
// Declare a template function for scaled conversion
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
scaled_type_convert
(
e8m0_bexp_t
scale
,
X
x
);
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_bexp_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
union
{
float
float_array
[
2
];
float2_t
float2_array
;
}
float_values
{};
float_values
.
float2_array
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
return
float_values
.
float_array
[
0
];
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_bexp_t
scale
,
f4x2_t
x
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{};
value
.
f4x2_array
[
0
]
=
x
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
()),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
())};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_bexp_t
scale
,
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
bitwise_value
{};
float2_t
op
;
float32_t
ret
;
// TODO: pack in a loop
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
0
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
2
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
3
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
4
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
5
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
6
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
7
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
8
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
9
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
10
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
11
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
12
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
13
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
14
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
bitwise_value
.
f4x2_array
[
0
]
=
value
.
fp4x2
[
15
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
bitwise_value
.
bitwise
,
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
0
>
());
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
template
AsType
<
f4x2_pk_t
>()[
Number
<
0
>
{}].
unpack
<
1
>
());
return
float_values
.
float32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_bexp_t
scale
,
float
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_bexp_t
scale
,
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_bexp_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
const
std
::
array
<
X
,
NumElems
>&
x
)
...
...
test/data_type/test_fp4.cpp
View file @
ace76831
...
...
@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/scaled_type_convert.hpp"
using
ck
::
e8m0_bexp_t
;
using
ck
::
f4_convert_rne
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment