Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a8cd34d6
Commit
a8cd34d6
authored
Nov 21, 2024
by
Rostyslav Geyyer
Browse files
Add vector conversions
parent
b5ac2abd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
759 additions
and
4 deletions
+759
-4
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+759
-4
No files found.
include/ck/utility/type_convert.hpp
View file @
a8cd34d6
...
...
@@ -524,6 +524,206 @@ inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with rne
inline
__host__
__device__
f4x2_t
f4_convert_rne
(
float2_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
fp4x2_t
f4x2_array
[
4
];
}
value
{
0
};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
value
.
bitwise
,
x
[
0
],
x
[
1
],
scale
,
0
);
return
value
.
f4x2_array
[
0
];
#else
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
uint8_t
h
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with rne
inline
__host__
__device__
f4x32_t
f4_convert_rne
(
float32_t
x
,
float
scale
=
1.0
f
)
{
#if defined(__gfx950__)
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{},
tmp_values
{};
// TODO: pack in a loop
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
0
],
x
[
1
],
scale
,
0
);
f4_values
.
f4x2_array
[
0
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
2
],
x
[
3
],
scale
,
0
);
f4_values
.
f4x2_array
[
1
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
4
],
x
[
5
],
scale
,
0
);
f4_values
.
f4x2_array
[
2
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
6
],
x
[
7
],
scale
,
0
);
f4_values
.
f4x2_array
[
3
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
8
],
x
[
9
],
scale
,
0
);
f4_values
.
f4x2_array
[
4
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
10
],
x
[
11
],
scale
,
0
);
f4_values
.
f4x2_array
[
5
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
12
],
x
[
13
],
scale
,
0
);
f4_values
.
f4x2_array
[
6
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
14
],
x
[
15
],
scale
,
0
);
f4_values
.
f4x2_array
[
7
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
16
],
x
[
17
],
scale
,
0
);
f4_values
.
f4x2_array
[
8
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
18
],
x
[
19
],
scale
,
0
);
f4_values
.
f4x2_array
[
9
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
20
],
x
[
21
],
scale
,
0
);
f4_values
.
f4x2_array
[
10
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
22
],
x
[
23
],
scale
,
0
);
f4_values
.
f4x2_array
[
11
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
24
],
x
[
25
],
scale
,
0
);
f4_values
.
f4x2_array
[
12
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
26
],
x
[
27
],
scale
,
0
);
f4_values
.
f4x2_array
[
13
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
28
],
x
[
29
],
scale
,
0
);
f4_values
.
f4x2_array
[
14
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32
(
tmp_values
.
bitwise
,
x
[
30
],
x
[
31
],
scale
,
0
);
f4_values
.
f4x2_array
[
15
]
=
tmp_values
.
f4x2_array
[
0
];
return
f4_values
.
f4x32_array
;
#else
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{};
// TODO: pack in a loop
auto
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
0
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
1
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
2
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
3
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
4
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
5
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
6
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
7
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
8
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
9
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
10
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
11
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
12
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
13
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
14
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
15
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
16
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
17
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
18
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
19
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
20
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
21
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
22
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
23
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
24
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
25
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
26
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
27
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
28
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
29
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
30
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type
<
f4_t
>
(
x
[
31
]
/
scale
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
return
f4_values
.
f4x32_array
;
#endif
}
// convert fp32 to fp4 with stochastic rounding
inline
__host__
__device__
f4_t
f4_convert_sr
(
float
x
,
float
scale
=
1.0
f
)
{
...
...
@@ -542,6 +742,215 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4 with sr
inline
__host__
__device__
f4x2_t
f4_convert_sr
(
float2_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
[
0
]);
#if defined(__gfx950__)
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
value
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
x
,
rng
,
scale
,
0
);
return
value
.
f4x2_array
[
0
];
#else
union
{
uint32_t
bitwise
;
f4x2_t
f4x2_array
[
4
];
}
value
{
0
};
uint8_t
l
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
uint8_t
h
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
value
.
bitwise
=
(
h
<<
4
)
|
l
;
return
value
.
f4x2_array
[
0
];
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4 with sr
inline
__host__
__device__
f4x32_t
f4_convert_sr
(
float32_t
x
,
float
scale
=
1.0
f
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
[
0
]);
#if defined(__gfx950__)
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
0
},
tmp_values
{
0
};
union
{
float2_t
floatx2_array
[
16
];
float32_t
floatx32_array
;
}
float_values
{
0
};
// TODO: pack in a loop
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
0
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
0
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
1
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
1
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
2
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
2
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
3
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
3
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
4
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
4
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
5
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
5
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
6
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
6
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
7
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
7
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
8
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
8
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
9
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
9
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
10
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
10
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
11
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
11
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
12
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
12
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
13
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
13
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
14
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
14
]
=
tmp_values
.
f4x2_array
[
0
];
tmp_values
.
bitwise
=
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32
(
value
.
bitwise
,
float_values
.
floatx2_array
[
15
],
rng
,
scale
,
0
);
f4_values
.
f4x2_array
[
15
]
=
tmp_values
.
f4x2_array
[
0
];
return
f4_values
.
f4x32_array
;
#else
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
0
};
// TODO: pack in a loop
auto
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
0
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
1
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
2
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
3
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
4
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
5
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
6
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
7
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
8
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
9
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
10
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
11
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
12
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
13
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
14
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
15
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
16
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
17
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
18
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
19
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
20
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
21
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
22
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
23
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
24
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
25
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
26
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
27
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
28
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
29
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
30
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
tmp
=
utils
::
sat_convert_to_type_sr
<
f4_t
>
(
x
[
31
]
/
scale
,
rng
);
f4_values
.
bitwise
<<=
4
;
f4_values
.
bitwise
|=
tmp
;
return
f4_values
.
f4x32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
type_convert
<
f4_t
,
float
>
(
float
x
)
...
...
@@ -553,16 +962,204 @@ inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
type_convert
<
f4x2_t
,
float2_t
>
(
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
type_convert
<
f4x32_t
,
float32_t
>
(
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
);
#else
return
f4_convert_rne
(
x
);
#endif
}
// convert fp4 to fp32
template
<
>
inline
__host__
__device__
float
type_convert
<
float
,
f4_t
>
(
f4_t
data
)
inline
__host__
__device__
float
type_convert
<
float
,
f4_t
>
(
f4_t
x
)
{
#if defined(__gfx950__)
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
data
,
scale
,
0
)
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
)
.
template
AsType
<
float
>()(
Number
<
0
>
{});
#else
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
data
);
return
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
type_convert
<
float2_t
,
f4x2_t
>
(
f4x2_t
x
)
{
#if defined(__gfx950__)
float
scale
=
1.0
f
;
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
x
.
unpack
(
1
)),
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
x
.
unpack
(
0
))};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
type_convert
<
float32_t
,
f4x32_t
>
(
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
float2_t
op
;
float32_t
ret
;
float
scale
=
1.0
f
;
// TODO: pack in a loop
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
0
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
1
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
2
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
3
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
4
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
5
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
6
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
7
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
8
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
9
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
10
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
11
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
12
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
13
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
14
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
15
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
0
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
1
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
2
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
3
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
4
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
5
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
6
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
7
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
8
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
9
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
10
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
11
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
12
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
13
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
14
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_scale_t
>::
Binary_1
(),
f4_values
.
f4x2_array
[
15
].
unpack
(
1
));
return
float_values
.
float32_array
;
#endif
}
...
...
@@ -587,13 +1184,147 @@ template <>
inline
__host__
__device__
float
scaled_type_convert
<
float
,
f4_t
>
(
e8m0_scale_t
scale
,
f4_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
scale
,
0
)
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
)
,
0
)
.
template
AsType
<
float
>()(
Number
<
0
>
{});
#else
return
utils
::
to_float
<
f4_t
>
(
scale
,
x
);
#endif
}
// convert vector of 2 fp4 to vector of 2 fp32
template
<
>
inline
__host__
__device__
float2_t
scaled_type_convert
<
float2_t
,
f4x2_t
>
(
e8m0_scale_t
scale
,
f4x2_t
x
)
{
#if defined(__gfx950__)
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
x
,
type_convert
<
float
>
(
scale
),
0
);
#else
float2_t
ret
{
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
unpack
(
1
)),
utils
::
to_float
<
f4_t
>
(
scale
,
x
.
unpack
(
0
))};
return
ret
;
#endif
}
// convert vector of 32 fp4 to vector of 32 fp32
template
<
>
inline
__host__
__device__
float32_t
scaled_type_convert
<
float32_t
,
f4x32_t
>
(
e8m0_scale_t
scale
,
f4x32_t
x
)
{
#if defined(__gfx950__)
union
{
f4x32_t
f4x32_array
;
f4x2_t
fp4x2
[
16
];
}
value
{
x
};
float2_t
op
;
float32_t
ret
;
// TODO: pack in a loop
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
0
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
0
]
=
op
[
0
];
ret
[
1
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
1
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
2
]
=
op
[
0
];
ret
[
3
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
2
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
4
]
=
op
[
0
];
ret
[
5
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
3
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
6
]
=
op
[
0
];
ret
[
7
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
4
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
8
]
=
op
[
0
];
ret
[
9
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
5
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
10
]
=
op
[
0
];
ret
[
11
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
6
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
12
]
=
op
[
0
];
ret
[
13
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
7
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
14
]
=
op
[
0
];
ret
[
15
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
8
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
16
]
=
op
[
0
];
ret
[
17
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
9
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
18
]
=
op
[
0
];
ret
[
19
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
10
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
20
]
=
op
[
0
];
ret
[
21
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
11
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
22
]
=
op
[
0
];
ret
[
23
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
12
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
24
]
=
op
[
0
];
ret
[
25
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
13
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
26
]
=
op
[
0
];
ret
[
27
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
14
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
28
]
=
op
[
0
];
ret
[
29
]
=
op
[
1
];
op
=
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4
(
value
.
fp4x2
[
15
],
type_convert
<
float
>
(
scale
),
0
);
ret
[
30
]
=
op
[
0
];
ret
[
31
]
=
op
[
1
];
return
ret
;
#else
union
{
float32_t
float32_array
;
float
float_array
[
32
];
}
float_values
{};
union
{
__uint128_t
bitwise
;
f4x2_t
f4x2_array
[
16
];
f4x32_t
f4x32_array
;
}
f4_values
{
bit_cast
<
__uint128_t
>
(
x
)};
// TODO: pack in a loop
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
0
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
1
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
2
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
3
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
4
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
5
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
6
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
7
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
8
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
9
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
10
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
11
].
unpack
(
1
));
float_values
.
float_array
[
0
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
unpack
(
0
));
float_values
.
float_array
[
1
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
12
].
unpack
(
1
));
float_values
.
float_array
[
2
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
unpack
(
0
));
float_values
.
float_array
[
3
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
13
].
unpack
(
1
));
float_values
.
float_array
[
4
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
unpack
(
0
));
float_values
.
float_array
[
5
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
14
].
unpack
(
1
));
float_values
.
float_array
[
6
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
unpack
(
0
));
float_values
.
float_array
[
7
]
=
utils
::
to_float
<
f4_t
>
(
scale
,
f4_values
.
f4x2_array
[
15
].
unpack
(
1
));
return
float_values
.
float32_array
;
#endif
}
// convert fp32 to fp4
template
<
>
inline
__host__
__device__
f4_t
scaled_type_convert
<
f4_t
,
float
>
(
e8m0_scale_t
scale
,
float
x
)
...
...
@@ -605,6 +1336,30 @@ inline __host__ __device__ f4_t scaled_type_convert<f4_t, float>(e8m0_scale_t sc
#endif
}
// convert vector of 2 fp32 to vector of 2 fp4
template
<
>
inline
__host__
__device__
f4x2_t
scaled_type_convert
<
f4x2_t
,
float2_t
>
(
e8m0_scale_t
scale
,
float2_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
// convert vector of 32 fp32 to vector of 32 fp4
template
<
>
inline
__host__
__device__
f4x32_t
scaled_type_convert
<
f4x32_t
,
float32_t
>
(
e8m0_scale_t
scale
,
float32_t
x
)
{
#if CK_USE_SR_F4_CONVERSION
return
f4_convert_sr
(
x
,
type_convert
<
float
>
(
scale
));
#else
return
f4_convert_rne
(
x
,
type_convert
<
float
>
(
scale
));
#endif
}
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
const
std
::
array
<
X
,
NumElems
>&
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment