Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7572a691
Commit
7572a691
authored
Feb 15, 2025
by
coderfeli
Browse files
merge develop
parents
7796fc73
6b6fcd37
Changes
452
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3034 additions
and
332 deletions
+3034
-332
include/ck/utility/array.hpp
include/ck/utility/array.hpp
+4
-2
include/ck/utility/blkgemmpipe_scheduler.hpp
include/ck/utility/blkgemmpipe_scheduler.hpp
+10
-2
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+3
-3
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+1467
-292
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+2
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+21
-8
include/ck/utility/e8m0.hpp
include/ck/utility/e8m0.hpp
+80
-0
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+18
-1
include/ck/utility/env.hpp
include/ck/utility/env.hpp
+3
-1
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+3
-3
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+6
-6
include/ck/utility/integral_constant.hpp
include/ck/utility/integral_constant.hpp
+6
-1
include/ck/utility/is_detected.hpp
include/ck/utility/is_detected.hpp
+9
-7
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+6
-1
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+5
-1
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+3
-3
include/ck/utility/mxf4_utils.hpp
include/ck/utility/mxf4_utils.hpp
+109
-0
include/ck/utility/mxf6_utils.hpp
include/ck/utility/mxf6_utils.hpp
+325
-0
include/ck/utility/mxf8_utils.hpp
include/ck/utility/mxf8_utils.hpp
+570
-0
include/ck/utility/mxfp_utils.hpp
include/ck/utility/mxfp_utils.hpp
+384
-0
No files found.
Too many changes to show.
To preserve performance only
452 of 452+
files are displayed.
Plain diff
Email patch
include/ck/utility/array.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
...
...
@@ -38,6 +38,8 @@ struct Array
}
__host__
__device__
constexpr
const
TData
*
begin
()
const
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
const
TData
*
end
()
const
{
return
&
mData
[
NSize
];
}
__host__
__device__
constexpr
TData
*
begin
()
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
TData
*
end
()
{
return
&
mData
[
NSize
];
}
};
// empty Array
...
...
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
using
data_type
=
remove_cvref_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...};
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Xs
>
(
xs
)...};
}
// make empty array
...
...
include/ck/utility/blkgemmpipe_scheduler.hpp
View file @
7572a691
...
...
@@ -103,14 +103,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL
);
printf
(
" A/B buffer load inst: %d, %d
\n
A/B LDS write inst: %d, %d
\n
A/B LDS read inst: "
"%d, %d
\n
C MFMA inst: %d
\n
"
,
"%d, %d
\n
C MFMA inst: %d
\n
"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d
\n
"
,
A_Buffer_Load_Inst_Num
,
B_Buffer_Load_Inst_Num
,
A_LDS_Write_Inst_Num
,
B_LDS_Write_Inst_Num
,
A_LDS_Read_Inst_Num
,
B_LDS_Read_Inst_Num
,
C_MFMA_Inst_Num
);
C_MFMA_Inst_Num
,
A_LDS_Read_Width
,
B_LDS_Read_Width
,
ALDSWriteWidth
,
BLDSWriteWidth
,
ABufferLoadWidth
,
BBufferLoadWidth
);
}
};
...
...
include/ck/utility/container_helper.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
...
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__
__device__
constexpr
auto
container_concat
(
const
Array
<
T
,
NX
>&
ax
,
const
Array
<
T
,
NY
>&
ay
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
}
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
container_concat
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
}
template
<
typename
Container
>
...
...
include/ck/utility/data_type.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#ifdef CK_CODE_GEN_RTC
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
using
uint16_t
=
unsigned
short
;
using
float_t
=
float
;
#endif
namespace
ck
{
#ifdef CK_CODE_GEN_RTC
using
byte
=
unsigned
char
;
#else
using
std
::
byte
;
#endif
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
f4_t
=
unsigned
_BitInt
(
4
);
using
f6_t
=
_BitInt
(
6
);
// e2m3 format
using
bf6_t
=
unsigned
_BitInt
(
6
);
// e3m2 format
struct
f4x2_pk_t
{
using
type
=
uint8_t
;
type
data
;
f4x2_pk_t
()
:
data
{
type
{}}
{}
f4x2_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
type
unpack
(
Number
<
I
>
)
const
{
static_assert
(
I
<
2
,
"Index is out of range."
);
if
constexpr
(
I
==
0
)
return
data
&
0b00001111
;
else
return
(
data
>>
4
);
}
__host__
__device__
inline
type
pack
(
const
type
x0
,
const
type
x1
)
{
return
(
x1
<<
4
)
|
(
x0
&
0b00001111
);
}
};
struct
f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
3
>
;
type
data
;
typedef
int8_t
test_vec_t
__attribute__
((
ext_vector_type
(
16
)));
f6x16_pk_t
()
:
data
{
type
{}}
{}
f6x16_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
f6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
16
,
"Index out of range for 16 f6_t elements."
);
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
);
}
return
static_cast
<
f6_t
>
(
bits
&
0x3F
);
}
__host__
__device__
inline
type
pack
(
const
test_vec_t
&
x
)
{
type
packed
{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
16
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
static_cast
<
int
>
(
i
)])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
struct
f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
6
>
;
type
data
;
typedef
int8_t
test_vec_t
__attribute__
((
ext_vector_type
(
32
)));
f6x32_pk_t
()
:
data
{
type
{}}
{}
f6x32_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
f6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
32
,
"Index out of range for 32 f6_t elements."
);
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
);
}
return
static_cast
<
f6_t
>
(
bits
&
0x3F
);
}
__host__
__device__
inline
type
pack
(
const
test_vec_t
&
x
)
{
type
packed
{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
static_cast
<
int
>
(
i
)])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
struct
bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
3
>
;
type
data
;
typedef
int8_t
test_vec_t
__attribute__
((
ext_vector_type
(
16
)));
bf6x16_pk_t
()
:
data
{
type
{}}
{}
bf6x16_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
bf6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
16
,
"Index out of range for 16 f6_t elements."
);
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
);
}
return
static_cast
<
bf6_t
>
(
bits
&
0x3F
);
}
__host__
__device__
inline
type
pack
(
const
test_vec_t
&
x
)
{
type
packed
{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
16
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
static_cast
<
int
>
(
i
)])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
3
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
struct
bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using
element_type
=
uint32_t
;
using
type
=
StaticallyIndexedArray_v2
<
element_type
,
6
>
;
type
data
;
typedef
int8_t
test_vec_t
__attribute__
((
ext_vector_type
(
32
)));
bf6x32_pk_t
()
:
data
{
type
{}}
{}
bf6x32_pk_t
(
type
init
)
:
data
{
init
}
{}
template
<
index_t
I
>
__host__
__device__
inline
bf6_t
unpack
(
Number
<
I
>
)
{
static_assert
(
I
<
32
,
"Index out of range for 32 f6_t elements."
);
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
I
*
num_bits_elem
;
constexpr
int
arr_idx
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
uint32_t
bits
=
data
.
At
(
Number
<
arr_idx
>
{})
>>
bit_offset
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
if
constexpr
(
overhang
>
0
&&
(
arr_idx
+
1
)
<
vector_size
)
{
bits
|=
(
data
.
At
(
Number
<
arr_idx
+
1
>
{})
&
((
1u
<<
overhang
)
-
1
))
<<
(
num_bits_elem
-
overhang
);
}
return
static_cast
<
bf6_t
>
(
bits
&
0x3F
);
}
__host__
__device__
inline
type
pack
(
const
test_vec_t
&
x
)
{
type
packed
{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck
::
static_for
<
0
,
32
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
bits
=
static_cast
<
uint32_t
>
(
x
[
static_cast
<
int
>
(
i
)])
&
0x3F
;
constexpr
int
num_bits_elem
=
6
;
constexpr
int
num_bits_vec_elem
=
32
;
constexpr
int
vector_size
=
6
;
constexpr
int
bit_pos
=
i
*
num_bits_elem
;
constexpr
int
arr_index
=
bit_pos
/
num_bits_vec_elem
;
constexpr
int
bit_offset
=
bit_pos
%
num_bits_vec_elem
;
constexpr
int
overhang
=
bit_offset
+
num_bits_elem
-
num_bits_vec_elem
;
uint32_t
old_value
=
packed
.
At
(
Number
<
arr_index
>
{});
// insert bits into the current 32-bit block
old_value
|=
(
bits
<<
bit_offset
);
packed
.
At
(
Number
<
arr_index
>
{})
=
old_value
;
// if it crosses into the next block, shift the remainder
if
constexpr
(
overhang
>
0
&&
(
arr_index
+
1
)
<
vector_size
)
{
uint32_t
next_value
=
packed
.
At
(
Number
<
arr_index
+
1
>
{});
next_value
|=
(
bits
>>
(
num_bits_elem
-
overhang
));
packed
.
At
(
Number
<
arr_index
+
1
>
{})
=
next_value
;
}
});
return
packed
;
}
};
// custom data type - pack int4 data
struct
pk_i4_t
{
using
type
=
int8_t
;
type
data
;
__host__
__device__
constexpr
pk_i4_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
pk_i4_t
(
type
init
)
:
data
{
init
}
{}
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
...
...
@@ -19,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x)
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool
// native types: bool
, f4_t, f6_t, bf6_t
template
<
typename
T
>
inline
constexpr
bool
is_native_type
()
{
return
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
bhalf_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
||
is_same
<
T
,
uint8_t
>::
value
||
is_same
<
T
,
f8_fnuz_t
>::
value
||
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
;
is_same
<
T
,
bf8_fnuz_t
>::
value
||
is_same
<
T
,
bool
>::
value
||
is_same
<
T
,
f4_t
>::
value
||
is_same
<
T
,
f6_t
>::
value
||
is_same
<
T
,
bf6_t
>::
value
;
}
// vector_type
...
...
@@ -165,6 +478,13 @@ struct scalar_type<int4_t>
};
#endif
template
<
>
struct
scalar_type
<
pk_i4_t
>
{
using
type
=
pk_i4_t
;
static
constexpr
index_t
vector_size
=
1
;
};
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
{
...
...
@@ -201,7 +521,7 @@ struct scalar_type<bool>
};
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
type
=
d1_t
;
...
...
@@ -237,7 +557,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__
int
static
err
=
0
;
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -297,20 +617,20 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
3
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d
4
_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d
3
_t
__attribute__
((
ext_vector_type
(
3
)));
using
type
=
d
4
_t
;
using
type
=
d
3
_t
;
union
{
d
4
_t
d
4
_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x
4
_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x
2
_
;
StaticallyIndexedArray
<
d
4
_t
,
1
>
d
4
x1_
;
d
3
_t
d
3
_
;
StaticallyIndexedArray
<
d1_t
,
3
>
d1x
3
_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x
1
_
;
StaticallyIndexedArray
<
d
3
_t
,
1
>
d
3
x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -320,20 +640,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
3
_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
4
_
;
return
data_
.
d1x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
2
_
;
return
data_
.
d2x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d
4
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
3
_t
>::
value
)
{
return
data_
.
d
4
x1_
;
return
data_
.
d
3
x1_
;
}
else
{
...
...
@@ -344,20 +664,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
4
_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d
3
_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
4
_
;
return
data_
.
d1x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
2
_
;
return
data_
.
d2x
1
_
;
}
else
if
constexpr
(
is_same
<
X
,
d
4
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
3
_t
>::
value
)
{
return
data_
.
d
4
x1_
;
return
data_
.
d
3
x1_
;
}
else
{
...
...
@@ -367,22 +687,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
using
type
=
d
8
_t
;
using
type
=
d
4
_t
;
union
{
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
d4_t
d4_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x4_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x2_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -392,25 +710,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
8
_
;
return
data_
.
d1x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
4
_
;
return
data_
.
d2x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
return
data_
.
d4x1_
;
}
else
{
...
...
@@ -421,25 +734,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
8
_
;
return
data_
.
d1x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
4
_
;
return
data_
.
d2x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
return
data_
.
d4x1_
;
}
else
{
...
...
@@ -449,24 +757,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
5
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d5_t
__attribute__
((
ext_vector_type
(
5
)));
using
type
=
d
16
_t
;
using
type
=
d
5
_t
;
union
{
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
d5_t
d5_
;
StaticallyIndexedArray
<
d1_t
,
5
>
d1x5_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d5_t
,
1
>
d5x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -476,30 +780,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
16
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
5
_t
>::
value
)
{
return
data_
.
d
16
x1_
;
return
data_
.
d
5
x1_
;
}
else
{
...
...
@@ -510,30 +804,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d5_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
return
data_
.
d1x5_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
16
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
5
_t
>::
value
)
{
return
data_
.
d
16
x1_
;
return
data_
.
d
5
x1_
;
}
else
{
...
...
@@ -543,26 +827,22 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
7
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d7_t
__attribute__
((
ext_vector_type
(
7
)));
using
type
=
d
32
_t
;
using
type
=
d
7
_t
;
union
{
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
d7_t
d7_
;
StaticallyIndexedArray
<
d1_t
,
7
>
d1x7_
;
StaticallyIndexedArray
<
d2_t
,
3
>
d2x3_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
StaticallyIndexedArray
<
d7_t
,
1
>
d7x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -573,33 +853,24 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
32
_
;
return
data_
.
d1x
7
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
16
_
;
return
data_
.
d2x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
32
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
7
_t
>::
value
)
{
return
data_
.
d
32
x1_
;
return
data_
.
d
7
x1_
;
}
else
{
...
...
@@ -611,33 +882,24 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d7_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
32
_
;
return
data_
.
d1x
7
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
16
_
;
return
data_
.
d2x
3
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
return
data_
.
d4x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d
32
_t
>::
value
)
else
if
constexpr
(
is_same
<
X
,
d
7
_t
>::
value
)
{
return
data_
.
d
32
x1_
;
return
data_
.
d
7
x1_
;
}
else
{
...
...
@@ -647,28 +909,22 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d
64
_t
;
using
type
=
d
8
_t
;
union
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
d8_t
d8_
;
StaticallyIndexedArray
<
d1_t
,
8
>
d1x8_
;
StaticallyIndexedArray
<
d2_t
,
4
>
d2x4_
;
StaticallyIndexedArray
<
d4_t
,
2
>
d4x2_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
...
...
@@ -679,34 +935,402 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x
64
_
;
return
data_
.
d1x
8
_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x
32
_
;
return
data_
.
d2x
4
_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x
16
_
;
return
data_
.
d4x
2
_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
else
{
return
data_
.
d32x2_
;
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
13
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d13_t
__attribute__
((
ext_vector_type
(
13
)));
using
type
=
d13_t
;
union
{
d13_t
d13_
;
StaticallyIndexedArray
<
d1_t
,
13
>
d1x13_
;
StaticallyIndexedArray
<
d4_t
,
3
>
d4x3_
;
StaticallyIndexedArray
<
d8_t
,
1
>
d8x1_
;
StaticallyIndexedArray
<
d13_t
,
1
>
d13x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d13_t
>::
value
)
{
return
data_
.
d13x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d13_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x13_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x3_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x1_
;
}
else
if
constexpr
(
is_same
<
X
,
d13_t
>::
value
)
{
return
data_
.
d13x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
using
type
=
d16_t
;
union
{
d16_t
d16_
;
StaticallyIndexedArray
<
d1_t
,
16
>
d1x16_
;
StaticallyIndexedArray
<
d2_t
,
8
>
d2x8_
;
StaticallyIndexedArray
<
d4_t
,
4
>
d4x4_
;
StaticallyIndexedArray
<
d8_t
,
2
>
d8x2_
;
StaticallyIndexedArray
<
d16_t
,
1
>
d16x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
using
type
=
d32_t
;
union
{
d32_t
d32_
;
StaticallyIndexedArray
<
d1_t
,
32
>
d1x32_
;
StaticallyIndexedArray
<
d2_t
,
16
>
d2x16_
;
StaticallyIndexedArray
<
d4_t
,
8
>
d4x8_
;
StaticallyIndexedArray
<
d8_t
,
4
>
d8x4_
;
StaticallyIndexedArray
<
d16_t
,
2
>
d16x2_
;
StaticallyIndexedArray
<
d32_t
,
1
>
d32x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x1_
;
}
else
{
return
err
;
}
}
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T
d4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
T
d8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
T
d16_t
__attribute__
((
ext_vector_type
(
16
)));
typedef
T
d32_t
__attribute__
((
ext_vector_type
(
32
)));
typedef
T
d64_t
__attribute__
((
ext_vector_type
(
64
)));
using
type
=
d64_t
;
union
{
d64_t
d64_
;
StaticallyIndexedArray
<
d1_t
,
64
>
d1x64_
;
StaticallyIndexedArray
<
d2_t
,
32
>
d2x32_
;
StaticallyIndexedArray
<
d4_t
,
16
>
d4x16_
;
StaticallyIndexedArray
<
d8_t
,
8
>
d8x8_
;
StaticallyIndexedArray
<
d16_t
,
4
>
d16x4_
;
StaticallyIndexedArray
<
d32_t
,
2
>
d32x2_
;
StaticallyIndexedArray
<
d64_t
,
1
>
d64x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{
0
}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
||
is_same
<
X
,
d8_t
>::
value
||
is_same
<
X
,
d16_t
>::
value
||
is_same
<
X
,
d32_t
>::
value
||
is_same
<
X
,
d64_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x64_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x32_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x16_
;
}
else
if
constexpr
(
is_same
<
X
,
d8_t
>::
value
)
{
return
data_
.
d8x8_
;
}
else
if
constexpr
(
is_same
<
X
,
d16_t
>::
value
)
{
return
data_
.
d16x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d32_t
>::
value
)
{
return
data_
.
d32x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d64_t
>::
value
)
{
...
...
@@ -763,7 +1387,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
128
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
128
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -889,7 +1513,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
256
,
typename
std
::
enable_if_t
<
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
256
,
typename
ck
::
enable_if_t
<
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
typedef
T
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
@@ -1038,17 +1662,48 @@ struct nnvb_data_t_selector<f8_ocp_t>
{
using
type
=
f8_ocp_t
::
data_type
;
};
template
<
>
struct
nnvb_data_t_selector
<
bf8_ocp_t
>
{
using
type
=
bf8_ocp_t
::
data_type
;
};
template
<
>
struct
nnvb_data_t_selector
<
f6x16_pk_t
>
{
using
type
=
f6x16_pk_t
::
type
;
};
template
<
>
struct
nnvb_data_t_selector
<
f6x32_pk_t
>
{
using
type
=
f6x32_pk_t
::
type
;
};
template
<
>
struct
nnvb_data_t_selector
<
bf6x16_pk_t
>
{
using
type
=
bf6x16_pk_t
::
type
;
};
template
<
>
struct
nnvb_data_t_selector
<
bf6x32_pk_t
>
{
using
type
=
bf6x32_pk_t
::
type
;
};
template
<
>
struct
nnvb_data_t_selector
<
pk_i4_t
>
{
using
type
=
pk_i4_t
::
type
;
};
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
T
,
N
,
std
::
enable_if_t
<
sizeof
(
T
)
==
1
||
sizeof
(
T
)
==
2
||
sizeof
(
T
)
==
4
||
sizeof
(
T
)
==
8
>>
ck
::
enable_if_t
<
sizeof
(
T
)
==
1
||
sizeof
(
T
)
==
2
||
sizeof
(
T
)
==
4
||
sizeof
(
T
)
==
8
>>
{
using
data_t
=
typename
nnvb_data_t_selector
<
T
>::
type
;
// select data_t based on the size of T
static_assert
(
sizeof
(
T
)
==
sizeof
(
data_t
),
"non_native_vector_base storage size mismatch"
);
...
...
@@ -1119,27 +1774,84 @@ struct non_native_vector_base<
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same_v
<
X
,
data_t
>
||
is_same_v
<
X
,
T
>
||
is_same_v
<
X
,
data_v
>
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same_v
<
X
,
data_t
>
)
{
return
data_
.
dxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
T
>
)
{
return
data_
.
dTxN
;
}
else
if
constexpr
(
is_same_v
<
X
,
data_v
>
)
{
return
data_
.
dNx1
;
}
else
{
return
err
;
}
}
};
// implementation for f6x16 and f6x32
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
T
,
N
,
std
::
enable_if_t
<
sizeof
(
T
)
==
12
||
sizeof
(
T
)
==
24
>>
{
using
data_t
=
typename
nnvb_data_t_selector
<
T
>::
type
;
// select data_t based on declared base type
using
element_t
=
typename
T
::
element_type
;
// select element_t based on declared element type
static_assert
(
sizeof
(
T
)
==
sizeof
(
data_t
),
"non_native_vector_base storage size mismatch"
);
static
constexpr
size_t
size_factor
=
sizeof
(
data_t
)
/
sizeof
(
element_t
);
// f6x16: 12/4 = 3, f6x32: 24/4 = 6
using
data_v
=
element_t
__attribute__
((
ext_vector_type
(
N
*
size_factor
)));
using
type
=
non_native_vector_base
<
T
,
N
>
;
union
alignas
(
next_pow2
(
N
*
sizeof
(
T
)))
{
data_v
dN
;
// storage vector;
StaticallyIndexedArray
<
data_t
,
N
>
dxN
;
StaticallyIndexedArray
<
T
,
N
>
dTxN
;
StaticallyIndexedArray
<
data_v
,
1
>
dNx1
;
}
data_
;
__host__
__device__
constexpr
non_native_vector_base
(
data_t
a
)
:
data_
{
data_v
(
a
.
At
(
Number
<
0
>
{}))}
{
static_assert
(
is_same_v
<
X
,
data_t
>
||
is_same_v
<
X
,
T
>
||
is_same_v
<
X
,
data_v
>
,
"Something went wrong, please check src and dst types."
);
}
__host__
__device__
constexpr
non_native_vector_base
(
T
f
)
:
non_native_vector_base
(
bit_cast
<
data_t
>
(
f
))
{
}
__host__
__device__
constexpr
non_native_vector_base
()
:
non_native_vector_base
(
T
{}){};
__host__
__device__
constexpr
non_native_vector_base
(
data_v
v
)
:
data_
{
v
}
{}
if
constexpr
(
is_same_v
<
X
,
data_t
>
)
__host__
__device__
constexpr
operator
data_v
()
const
{
return
data_
.
dN
;
}
__host__
__device__
constexpr
operator
data_t
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
dxN
;
return
data_
.
dxN
[
Number
<
0
>
{}]
;
}
else
if
constexpr
(
is_same_v
<
X
,
T
>
)
else
{
return
data_
.
d
T
xN
;
return
data_
.
dxN
;
// XXX this should cause an error
}
else
if
constexpr
(
is_same_v
<
X
,
data_v
>
)
}
__host__
__device__
constexpr
operator
T
()
const
{
if
constexpr
(
N
==
1
)
{
return
data_
.
d
Nx1
;
return
data_
.
d
TxN
[
Number
<
0
>
{}]
;
}
else
{
return
err
;
return
data_
.
dTxN
;
// XXX this should cause an
err
or
}
}
};
...
...
@@ -1163,9 +1875,17 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static
constexpr
index_t
vector_size
=
N
;
};
template
<
index_t
N
>
struct
scalar_type
<
non_native_vector_base
<
pk_i4_t
,
N
>>
{
using
type
=
typename
non_native_vector_base
<
pk_i4_t
,
N
>::
data_t
;
static
constexpr
index_t
vector_size
=
N
;
};
// non-native vector_type implementation
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
...
...
@@ -1216,7 +1936,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
2
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
...
...
@@ -1279,7 +1999,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
4
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
...
...
@@ -1352,7 +2072,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
...
...
@@ -1437,7 +2157,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
16
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
16
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d1_nnv_t
=
non_native_vector_base
<
T
,
1
>
;
...
...
@@ -1532,7 +2252,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
32
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
32
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
...
...
@@ -1636,7 +2356,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
};
template
<
typename
T
>
struct
vector_type
<
T
,
64
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
64
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
using
d1_t
=
T
;
using
d2_t
=
non_native_vector_base
<
T
,
2
>
;
...
...
@@ -1751,140 +2471,371 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
}
};
using
int64_t
=
long
;
using
int64_t
=
long
;
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
using
float16_t
=
typename
vector_type
<
float
,
16
>::
type
;
using
float32_t
=
typename
vector_type
<
float
,
32
>::
type
;
using
float64_t
=
typename
vector_type
<
float
,
64
>::
type
;
// fp16
using
half2_t
=
typename
vector_type
<
half_t
,
2
>::
type
;
using
half4_t
=
typename
vector_type
<
half_t
,
4
>::
type
;
using
half8_t
=
typename
vector_type
<
half_t
,
8
>::
type
;
using
half16_t
=
typename
vector_type
<
half_t
,
16
>::
type
;
using
half32_t
=
typename
vector_type
<
half_t
,
32
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
// bfp16
using
bhalf2_t
=
typename
vector_type
<
bhalf_t
,
2
>::
type
;
using
bhalf4_t
=
typename
vector_type
<
bhalf_t
,
4
>::
type
;
using
bhalf8_t
=
typename
vector_type
<
bhalf_t
,
8
>::
type
;
using
bhalf16_t
=
typename
vector_type
<
bhalf_t
,
16
>::
type
;
using
bhalf32_t
=
typename
vector_type
<
bhalf_t
,
32
>::
type
;
using
bhalf64_t
=
typename
vector_type
<
bhalf_t
,
64
>::
type
;
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
// i8
using
int8x2_t
=
typename
vector_type
<
int8_t
,
2
>::
type
;
using
int8x4_t
=
typename
vector_type
<
int8_t
,
4
>::
type
;
using
int8x8_t
=
typename
vector_type
<
int8_t
,
8
>::
type
;
using
int8x16_t
=
typename
vector_type
<
int8_t
,
16
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// f8
using
f8x2_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
2
>::
type
;
using
f8x4_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
4
>::
type
;
using
f8x8_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
8
>::
type
;
using
f8x16_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
16
>::
type
;
using
f8x32_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
32
>::
type
;
using
f8x64_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
64
>::
type
;
// bf8
using
bf8x2_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
2
>::
type
;
using
bf8x4_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
4
>::
type
;
using
bf8x8_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
8
>::
type
;
using
bf8x16_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
16
>::
type
;
using
bf8x32_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
32
>::
type
;
using
bf8x64_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
64
>::
type
;
// f8
using
f8x2_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
2
>::
type
;
using
f8x4_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
4
>::
type
;
using
f8x8_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
8
>::
type
;
using
f8x16_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
16
>::
type
;
using
f8x32_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
32
>::
type
;
using
f8x64_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
64
>::
type
;
// bf8
using
bf8x2_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
2
>::
type
;
using
bf8x4_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
4
>::
type
;
using
bf8x8_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
8
>::
type
;
using
bf8x16_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
16
>::
type
;
using
bf8x32_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
32
>::
type
;
using
bf8x64_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
64
>::
type
;
#if CK_FP8_TYPE_OCP
// f8
using
f8x2_t
=
f8x2_ocp_t
;
using
f8x4_t
=
f8x4_ocp_t
;
using
f8x8_t
=
f8x8_ocp_t
;
using
f8x16_t
=
f8x16_ocp_t
;
using
f8x32_t
=
f8x32_ocp_t
;
using
f8x64_t
=
f8x64_ocp_t
;
// bf8
using
bf8x2_t
=
bf8x2_ocp_t
;
using
bf8x4_t
=
bf8x4_ocp_t
;
using
bf8x8_t
=
bf8x8_ocp_t
;
using
bf8x16_t
=
bf8x16_ocp_t
;
using
bf8x32_t
=
bf8x32_ocp_t
;
using
bf8x64_t
=
bf8x64_ocp_t
;
#elif CK_FP8_TYPE_FNUZ
// f8
using
f8x2_t
=
f8x2_fnuz_t
;
using
f8x4_t
=
f8x4_fnuz_t
;
using
f8x8_t
=
f8x8_fnuz_t
;
using
f8x16_t
=
f8x16_fnuz_t
;
using
f8x32_t
=
f8x32_fnuz_t
;
using
f8x64_t
=
f8x64_fnuz_t
;
// bf8
using
bf8x2_t
=
bf8x2_fnuz_t
;
using
bf8x4_t
=
bf8x4_fnuz_t
;
using
bf8x8_t
=
bf8x8_fnuz_t
;
using
bf8x16_t
=
bf8x16_fnuz_t
;
using
bf8x32_t
=
bf8x32_fnuz_t
;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// f4
using
f4x2_t
=
typename
vector_type
<
f4x2_pk_t
,
1
>::
type
;
using
f4x4_t
=
typename
vector_type
<
f4x2_pk_t
,
2
>::
type
;
using
f4x8_t
=
typename
vector_type
<
f4x2_pk_t
,
4
>::
type
;
using
f4x16_t
=
typename
vector_type
<
f4x2_pk_t
,
8
>::
type
;
using
f4x32_t
=
typename
vector_type
<
f4x2_pk_t
,
16
>::
type
;
using
f4x64_t
=
typename
vector_type
<
f4x2_pk_t
,
32
>::
type
;
// f6
using
f6x16_t
=
typename
vector_type
<
f6x16_pk_t
,
1
>::
type
;
using
f6x32_t
=
typename
vector_type
<
f6x32_pk_t
,
1
>::
type
;
// bf6
using
bf6x16_t
=
typename
vector_type
<
bf6x16_pk_t
,
1
>::
type
;
using
bf6x32_t
=
typename
vector_type
<
bf6x32_pk_t
,
1
>::
type
;
// pack int4
using
pk_i4x2_t
=
typename
vector_type
<
pk_i4_t
,
2
>::
type
;
using
pk_i4x4_t
=
typename
vector_type
<
pk_i4_t
,
4
>::
type
;
using
pk_i4x8_t
=
typename
vector_type
<
pk_i4_t
,
8
>::
type
;
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
struct
NumericLimits
;
template
<
>
struct
NumericLimits
<
int32_t
>
{
__host__
__device__
static
constexpr
int32_t
Lowest
()
noexcept
{
return
-
2147483647
-
1
;
}
__host__
__device__
static
constexpr
int32_t
Min
()
noexcept
{
return
-
2147483647
-
1
;
}
__host__
__device__
static
constexpr
int32_t
Max
()
noexcept
{
return
2147483647
;
}
__host__
__device__
static
constexpr
int32_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int32_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
int16_t
>
{
__host__
__device__
static
constexpr
int16_t
Lowest
()
noexcept
{
return
-
32768
;
}
__host__
__device__
static
constexpr
int16_t
Min
()
noexcept
{
return
-
32768
;
}
__host__
__device__
static
constexpr
int16_t
Max
()
noexcept
{
return
32767
;
}
__host__
__device__
static
constexpr
int16_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int16_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
int8_t
>
{
__host__
__device__
static
constexpr
int8_t
Lowest
()
noexcept
{
return
-
128
;
}
__host__
__device__
static
constexpr
int8_t
Min
()
noexcept
{
return
-
128
;
}
__host__
__device__
static
constexpr
int8_t
Max
()
noexcept
{
return
127
;
}
__host__
__device__
static
constexpr
int8_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int8_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
uint32_t
>
{
__host__
__device__
static
constexpr
uint32_t
Lowest
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
Min
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
Max
()
noexcept
{
return
4294967295U
;
}
__host__
__device__
static
constexpr
uint32_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
uint16_t
>
{
__host__
__device__
static
constexpr
uint16_t
Lowest
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
Min
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
Max
()
noexcept
{
return
65535U
;
}
__host__
__device__
static
constexpr
uint16_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
float
>
{
static
constexpr
unsigned
int
binary_min
=
0x00800000
;
static
constexpr
unsigned
int
binary_max
=
0x7F7FFFFF
;
static
constexpr
unsigned
int
binary_lowest
=
0xFF7FFFFF
;
static
constexpr
unsigned
int
binary_qnan
=
0xFFC00001
;
static
constexpr
unsigned
int
binary_inf
=
0x7F8000000
;
__host__
__device__
static
constexpr
float
Min
()
{
return
bit_cast
<
float
>
(
binary_min
);
}
__host__
__device__
static
constexpr
float
Max
()
{
return
bit_cast
<
float
>
(
binary_max
);
}
__host__
__device__
static
constexpr
float
Lowest
()
{
return
bit_cast
<
float
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
float
QuietNaN
()
{
return
bit_cast
<
float
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
float
Infinity
()
{
return
bit_cast
<
float
>
(
binary_inf
);
}
};
template
<
>
struct
NumericLimits
<
half_t
>
{
static
constexpr
unsigned
short
binary_min
=
0x0400
;
static
constexpr
unsigned
short
binary_max
=
0x7BFF
;
static
constexpr
unsigned
short
binary_lowest
=
0xFBFF
;
static
constexpr
unsigned
short
binary_qnan
=
0x7FFF
;
__host__
__device__
static
constexpr
half_t
Min
()
{
return
bit_cast
<
half_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
half_t
Max
()
{
return
bit_cast
<
half_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
half_t
Lowest
()
{
return
bit_cast
<
half_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
int4_t
>
{
__host__
__device__
static
constexpr
int4_t
Min
()
{
return
int4_t
(
-
8
);
}
__host__
__device__
static
constexpr
int4_t
Max
()
{
return
int4_t
(
7
);
}
__host__
__device__
static
constexpr
int4_t
Lowest
()
{
return
int4_t
(
-
8
);
}
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
f8_fnuz_t
>
{
// negative zero nan mode with exp bias = 8
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__
__device__
static
constexpr
f8_fnuz_t
Min
()
{
return
f8_fnuz_t
(
binary_min
);
}
// fp64
using
double2_t
=
typename
vector_type
<
double
,
2
>::
type
;
using
double4_t
=
typename
vector_type
<
double
,
4
>::
type
;
__host__
__device__
static
constexpr
f8_fnuz_t
Max
()
{
return
f8_fnuz_t
(
binary_max
);
}
// fp32
using
float2_t
=
typename
vector_type
<
float
,
2
>::
type
;
using
float4_t
=
typename
vector_type
<
float
,
4
>::
type
;
using
float8_t
=
typename
vector_type
<
float
,
8
>::
type
;
using
float16_t
=
typename
vector_type
<
float
,
16
>::
type
;
using
float32_t
=
typename
vector_type
<
float
,
32
>::
type
;
using
float64_t
=
typename
vector_type
<
float
,
64
>::
type
;
__host__
__device__
static
constexpr
f8_fnuz_t
Lowest
()
{
return
f8_fnuz_t
(
binary_lowest
);
}
// fp16
using
half2_t
=
typename
vector_type
<
half_t
,
2
>::
type
;
using
half4_t
=
typename
vector_type
<
half_t
,
4
>::
type
;
using
half8_t
=
typename
vector_type
<
half_t
,
8
>::
type
;
using
half16_t
=
typename
vector_type
<
half_t
,
16
>::
type
;
using
half32_t
=
typename
vector_type
<
half_t
,
32
>::
type
;
using
half64_t
=
typename
vector_type
<
half_t
,
64
>::
type
;
__host__
__device__
static
constexpr
f8_fnuz_t
QuietNaN
()
{
return
f8_fnuz_t
(
binary_qnan
);
}
};
// bfp16
using
bhalf2_t
=
typename
vector_type
<
bhalf_t
,
2
>::
type
;
using
bhalf4_t
=
typename
vector_type
<
bhalf_t
,
4
>::
type
;
using
bhalf8_t
=
typename
vector_type
<
bhalf_t
,
8
>::
type
;
using
bhalf16_t
=
typename
vector_type
<
bhalf_t
,
16
>::
type
;
using
bhalf32_t
=
typename
vector_type
<
bhalf_t
,
32
>::
type
;
using
bhalf64_t
=
typename
vector_type
<
bhalf_t
,
64
>::
type
;
template
<
>
struct
NumericLimits
<
bf8_fnuz_t
>
{
// negative zero nan mode with exp bias = 16
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100
static
constexpr
uint8_t
binary_max
=
0x7F
;
// 0b01111111
static
constexpr
uint8_t
binary_lowest
=
0xFF
;
// 0b11111111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
// i32
using
int32x2_t
=
typename
vector_type
<
int32_t
,
2
>::
type
;
using
int32x4_t
=
typename
vector_type
<
int32_t
,
4
>::
type
;
using
int32x8_t
=
typename
vector_type
<
int32_t
,
8
>::
type
;
using
int32x16_t
=
typename
vector_type
<
int32_t
,
16
>::
type
;
using
int32x32_t
=
typename
vector_type
<
int32_t
,
32
>::
type
;
using
int32x64_t
=
typename
vector_type
<
int32_t
,
64
>::
type
;
__host__
__device__
static
constexpr
bf8_fnuz_t
Min
()
{
return
bf8_fnuz_t
(
binary_min
);
}
// i8
using
int8x2_t
=
typename
vector_type
<
int8_t
,
2
>::
type
;
using
int8x4_t
=
typename
vector_type
<
int8_t
,
4
>::
type
;
using
int8x8_t
=
typename
vector_type
<
int8_t
,
8
>::
type
;
using
int8x16_t
=
typename
vector_type
<
int8_t
,
16
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
__host__
__device__
static
constexpr
bf8_fnuz_t
Max
()
{
return
bf8_fnuz_t
(
binary_max
);
}
// f8
using
f8x2_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
2
>::
type
;
using
f8x4_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
4
>::
type
;
using
f8x8_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
8
>::
type
;
using
f8x16_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
16
>::
type
;
using
f8x32_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
32
>::
type
;
using
f8x64_fnuz_t
=
typename
vector_type
<
f8_fnuz_t
,
64
>::
type
;
__host__
__device__
static
constexpr
bf8_fnuz_t
Lowest
()
{
return
bf8_fnuz_t
(
binary_lowest
);
}
// bf8
using
bf8x2_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
2
>::
type
;
using
bf8x4_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
4
>::
type
;
using
bf8x8_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
8
>::
type
;
using
bf8x16_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
16
>::
type
;
using
bf8x32_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
32
>::
type
;
using
bf8x64_fnuz_t
=
typename
vector_type
<
bf8_fnuz_t
,
64
>::
type
;
__host__
__device__
static
constexpr
bf8_fnuz_t
QuietNaN
()
{
return
bf8_fnuz_t
(
binary_qnan
);
}
};
// f8
using
f8x2_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
2
>::
type
;
using
f8x4_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
4
>::
type
;
using
f8x8_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
8
>::
type
;
using
f8x16_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
16
>::
type
;
using
f8x32_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
32
>::
type
;
using
f8x64_ocp_t
=
typename
vector_type
<
f8_ocp_t
,
64
>::
type
;
template
<
>
struct
NumericLimits
<
f8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x08
;
// 0b00001000 = 2^-6
static
constexpr
uint8_t
binary_max
=
0x7E
;
// 0b01111110 = 448
static
constexpr
uint8_t
binary_lowest
=
0xFE
;
// 0b11111110 = -448
static
constexpr
uint8_t
binary_qnan
=
0x7F
;
// 0b01111111
// bf8
using
bf8x2_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
2
>::
type
;
using
bf8x4_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
4
>::
type
;
using
bf8x8_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
8
>::
type
;
using
bf8x16_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
16
>::
type
;
using
bf8x32_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
32
>::
type
;
using
bf8x64_ocp_t
=
typename
vector_type
<
bf8_ocp_t
,
64
>::
type
;
__host__
__device__
static
constexpr
f8_ocp_t
Min
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_min
);
}
#if CK_FP8_TYPE_OCP
// f8
using
f8x2_t
=
f8x2_ocp_t
;
using
f8x4_t
=
f8x4_ocp_t
;
using
f8x8_t
=
f8x8_ocp_t
;
using
f8x16_t
=
f8x16_ocp_t
;
using
f8x32_t
=
f8x32_ocp_t
;
using
f8x64_t
=
f8x64_ocp_t
;
__host__
__device__
static
constexpr
f8_ocp_t
Max
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_max
);
}
// bf8
using
bf8x2_t
=
bf8x2_ocp_t
;
using
bf8x4_t
=
bf8x4_ocp_t
;
using
bf8x8_t
=
bf8x8_ocp_t
;
using
bf8x16_t
=
bf8x16_ocp_t
;
using
bf8x32_t
=
bf8x32_ocp_t
;
using
bf8x64_t
=
bf8x64_ocp_t
;
#elif CK_FP8_TYPE_FNUZ
// f8
using
f8x2_t
=
f8x2_fnuz_t
;
using
f8x4_t
=
f8x4_fnuz_t
;
using
f8x8_t
=
f8x8_fnuz_t
;
using
f8x16_t
=
f8x16_fnuz_t
;
using
f8x32_t
=
f8x32_fnuz_t
;
using
f8x64_t
=
f8x64_fnuz_t
;
__host__
__device__
static
constexpr
f8_ocp_t
Lowest
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_lowest
);
}
// bf8
using
bf8x2_t
=
bf8x2_fnuz_t
;
using
bf8x4_t
=
bf8x4_fnuz_t
;
using
bf8x8_t
=
bf8x8_fnuz_t
;
using
bf8x16_t
=
bf8x16_fnuz_t
;
using
bf8x32_t
=
bf8x32_fnuz_t
;
using
bf8x64_t
=
bf8x64_fnuz_t
;
#endif
__host__
__device__
static
constexpr
f8_ocp_t
QuietNaN
()
{
return
bit_cast
<
f8_ocp_t
>
(
binary_qnan
);
}
};
// u8
using
uint8x2_t
=
typename
vector_type
<
uint8_t
,
2
>::
type
;
using
uint8x4_t
=
typename
vector_type
<
uint8_t
,
4
>::
type
;
using
uint8x8_t
=
typename
vector_type
<
uint8_t
,
8
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
template
<
>
struct
NumericLimits
<
bf8_ocp_t
>
{
static
constexpr
uint8_t
binary_min
=
0x04
;
// 0b00000100 = 2^-14
static
constexpr
uint8_t
binary_max
=
0x7B
;
// 0b01111011 = 57344
static
constexpr
uint8_t
binary_lowest
=
0xFB
;
// 0b11111011 = -57344
static
constexpr
uint8_t
binary_qnan
=
0x7D
;
// 0b01111101
__host__
__device__
static
constexpr
bf8_ocp_t
Min
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Max
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
Lowest
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
bf8_ocp_t
QuietNaN
()
{
return
bit_cast
<
bf8_ocp_t
>
(
binary_qnan
);
}
};
#else
template
<
typename
T
>
struct
NumericLimits
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
__host__
__device__
static
constexpr
T
QuietNaN
()
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
();
}
};
...
...
@@ -2008,6 +2959,119 @@ struct NumericLimits<bf8_ocp_t>
return
bit_cast
<
bf8_ocp_t
>
(
binary_qnan
);
}
};
#endif
template
<
>
struct
NumericLimits
<
f4_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x2
;
// 0b0010
static
constexpr
uint8_t
binary_max_normal
=
0x7
;
// 0b0111
static
constexpr
uint8_t
binary_lowest_normal
=
0xF
;
// 0b1111
static
constexpr
uint8_t
binary_min_subnorm
=
0x1
;
// 0b0001
static
constexpr
uint8_t
binary_max_subnorm
=
0x1
;
// 0b0001
static
constexpr
float
data_max_normal_number
=
6
;
static
constexpr
float
data_min_subnormal_number
=
0.5
;
__host__
__device__
static
constexpr
f4_t
Min
()
{
return
f4_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
f4_t
Max
()
{
return
f4_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
f4_t
Lowest
()
{
return
f4_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
f4_t
MinSubnorm
()
{
return
f4_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
f4_t
MaxSubnorm
()
{
return
f4_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
f6_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x08
;
// 0b001000
static
constexpr
uint8_t
binary_max_normal
=
0x1F
;
// 0b011111
static
constexpr
uint8_t
binary_lowest_normal
=
0x3F
;
// 0b111111
static
constexpr
uint8_t
binary_min_subnorm
=
0x01
;
// 0b000001
static
constexpr
uint8_t
binary_max_subnorm
=
0x07
;
// 0b000111
static
constexpr
float
data_max_normal_number
=
7.5
;
static
constexpr
float
data_min_subnormal_number
=
0.125
;
__host__
__device__
static
constexpr
f6_t
Min
()
{
return
f6_t
(
binary_min_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
Max
()
{
return
f6_t
(
binary_max_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
Lowest
()
{
return
f6_t
(
binary_lowest_normal
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
MinSubnorm
()
{
return
f6_t
(
binary_min_subnorm
&
0b111111
);
}
__host__
__device__
static
constexpr
f6_t
MaxSubnorm
()
{
return
f6_t
(
binary_max_subnorm
&
0b111111
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
bf6_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x08
;
// 0b001000
static
constexpr
uint8_t
binary_max_normal
=
0x1F
;
// 0b011111
static
constexpr
uint8_t
binary_lowest_normal
=
0x3F
;
// 0b111111
static
constexpr
uint8_t
binary_min_subnorm
=
0x01
;
// 0b000001
static
constexpr
uint8_t
binary_max_subnorm
=
0x03
;
// 0b000011
static
constexpr
float
data_max_normal_number
=
28
;
static
constexpr
float
data_min_subnormal_number
=
0.0625
;
__host__
__device__
static
constexpr
bf6_t
Min
()
{
return
bf6_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
bf6_t
Max
()
{
return
bf6_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
bf6_t
Lowest
()
{
return
bf6_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
bf6_t
MinSubnorm
()
{
return
bf6_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
bf6_t
MaxSubnorm
()
{
return
bf6_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
e8m0_bexp_t
>
{
static
constexpr
e8m0_bexp_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_bexp_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_bexp_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_bexp_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_bexp_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_bexp_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_bexp_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_bexp_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_bexp_t
Min
()
{
return
e8m0_bexp_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Max
()
{
return
e8m0_bexp_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
QuietNaN
()
{
return
e8m0_bexp_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_1
()
{
return
e8m0_bexp_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_2
()
{
return
e8m0_bexp_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_3
()
{
return
e8m0_bexp_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_135
()
{
return
e8m0_bexp_t
(
binary_135
);
}
__host__
__device__
static
constexpr
e8m0_bexp_t
Binary_142
()
{
return
e8m0_bexp_t
(
binary_142
);
}
};
template
<
typename
T
>
struct
NumericUtils
...
...
@@ -2028,6 +3092,7 @@ struct NumericUtils<float>
static
constexpr
uint32_t
NegInf
=
0xFF800000
;
static
constexpr
uint32_t
NaN
=
0x7F800001
;
static
constexpr
uint32_t
Neg0
=
0x80000000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint32_t
;
};
...
...
@@ -2045,9 +3110,19 @@ struct NumericUtils<half_t>
static
constexpr
uint32_t
NegInf
=
0xFC00
;
static
constexpr
uint32_t
NaN
=
0x7C01
;
static
constexpr
uint32_t
Neg0
=
0x8000
;
static
constexpr
bool
has_inf
=
true
;
using
bitwise_type
=
uint16_t
;
};
template
<
>
struct
NumericUtils
<
bhalf_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
static
constexpr
int
bias
=
128
;
// negative zero nan mode
// static constexpr int bias = 127; // ieee mode
};
template
<
>
struct
NumericUtils
<
f8_fnuz_t
>
{
...
...
@@ -2055,6 +3130,7 @@ struct NumericUtils<f8_fnuz_t>
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
8
;
// negative zero nan mode
// static constexpr int bias = 7; // ieee mode
static
constexpr
bool
has_inf
=
false
;
};
template
<
>
...
...
@@ -2064,6 +3140,7 @@ struct NumericUtils<bf8_fnuz_t>
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
16
;
// negative zero nan mode
// static constexpr int bias = 15; // ieee mode
static
constexpr
bool
has_inf
=
false
;
};
template
<
>
struct
NumericUtils
<
f8_ocp_t
>
...
...
@@ -2082,11 +3159,109 @@ struct NumericUtils<bf8_ocp_t>
};
template
<
>
struct
NumericUtils
<
bhalf_t
>
struct
NumericUtils
<
f4_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
1
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
10
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b0000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b1000
;
static
constexpr
uint8_t
one_mask
=
0b0010
;
static
constexpr
uint8_t
set_sign_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b0111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b1111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b0001
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b1001
;
static
constexpr
bool
has_inf
=
false
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
f6_t
>
{
static
constexpr
int
exp
=
2
;
static
constexpr
int
mant
=
3
;
static
constexpr
int
bias
=
1
;
static
constexpr
uint32_t
sr_shift
=
12
;
static
constexpr
int
unbiased_exp_min
=
0
;
static
constexpr
int
unbiased_exp_max
=
2
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
3
;
static
constexpr
uint8_t
positive_zero_mask
=
0b000000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b100000
;
static
constexpr
uint8_t
set_sign_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b111111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b000111
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b100111
;
static
constexpr
bool
has_inf
=
false
;
static
constexpr
bool
has_nan
=
false
;
static
constexpr
bool
has_zero
=
true
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
bf6_t
>
{
static
constexpr
int
exp
=
3
;
static
constexpr
int
mant
=
2
;
static
constexpr
int
bias
=
3
;
static
constexpr
uint32_t
sr_shift
=
11
;
static
constexpr
int
unbiased_exp_min
=
-
2
;
static
constexpr
int
unbiased_exp_max
=
4
;
static
constexpr
int
biased_exp_min
=
1
;
static
constexpr
int
biased_exp_max
=
7
;
static
constexpr
uint8_t
positive_zero_mask
=
0b000000
;
static
constexpr
uint8_t
negative_zero_mask
=
0b100000
;
static
constexpr
uint8_t
set_sign_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_positive_normal_mask
=
0b011111
;
static
constexpr
uint8_t
data_max_negative_normal_mask
=
0b111111
;
static
constexpr
uint8_t
data_max_positive_subnormal_mask
=
0b000011
;
static
constexpr
uint8_t
data_max_negative_subnormal_mask
=
0b100011
;
static
constexpr
bool
has_inf
=
false
;
static
constexpr
bool
has_nan
=
false
;
static
constexpr
bool
has_zero
=
true
;
using
bitwise_type
=
uint8_t
;
};
template
<
>
struct
NumericUtils
<
e8m0_bexp_t
>
{
static
constexpr
int
exp
=
8
;
static
constexpr
int
mant
=
7
;
static
constexpr
int
bias
=
128
;
// negative zero nan mode
// static constexpr int bias = 127; // ieee mode
static
constexpr
int
mant
=
0
;
static
constexpr
int
bias
=
127
;
static
constexpr
int
unbiased_exp_min
=
-
127
;
static
constexpr
int
unbiased_exp_max
=
127
;
static
constexpr
int
biased_exp_min
=
0
;
static
constexpr
int
biased_exp_max
=
254
;
using
bitwise_type
=
uint8_t
;
};
}
// namespace ck
include/ck/utility/debug.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace
ck
{
namespace
debug
{
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
7572a691
...
...
@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize
element_space_size_
;
T
invalid_element_value_
=
T
{
0
};
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
T
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__host__
__device__
constexpr
DynamicBuffer
(
T
*
p_data
,
ElementSpaceSize
element_space_size
)
:
p_data_
{
p_data
},
element_space_size_
{
element_space_size
}
{
...
...
@@ -54,7 +61,8 @@ struct DynamicBuffer
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
...
...
@@ -81,14 +89,18 @@ struct DynamicBuffer
return
amd_buffer_load_invalid_element_return_zero
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
{
return
amd_buffer_load_invalid_element_return_customized_value
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
p_data_
,
i
,
is_valid_element
,
element_space_size_
,
invalid_element_value_
);
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
,
invalid_element_value_
);
}
}
else
...
...
@@ -190,12 +202,13 @@ struct DynamicBuffer
dst_buf
.
p_data_
,
dst_offset
,
is_valid_element
,
element_space_size_
);
element_space_size_
/
PackedSize
);
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
...
...
@@ -224,7 +237,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
if
constexpr
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
&&
is_same
<
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
,
int8_t
>::
value
&&
...
...
@@ -376,7 +389,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
{
...
...
@@ -415,7 +428,7 @@ struct DynamicBuffer
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
/
PackedSize
);
}
else
if
(
is_valid_element
)
{
...
...
include/ck/utility/e8m0.hpp
0 → 100644
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace
ck
{
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct
e8m0_bexp_t
{
using
type
=
uint8_t
;
type
data
;
constexpr
static
type
bias
=
127
;
constexpr
static
type
nan_mask
=
0xFF
;
__host__
__device__
constexpr
e8m0_bexp_t
()
:
data
{
type
{}}
{}
__host__
__device__
constexpr
e8m0_bexp_t
(
type
init
)
:
data
{
init
}
{}
__host__
__device__
constexpr
e8m0_bexp_t
(
int
init
)
:
data
{
static_cast
<
type
>
(
init
&
nan_mask
)}
{
}
__host__
__device__
explicit
constexpr
e8m0_bexp_t
(
float
scale
)
:
data
{
static_cast
<
type
>
((
bit_cast
<
uint32_t
>
(
scale
)
&
(
nan_mask
<<
23
))
>>
23
)}
{
}
__host__
__device__
explicit
constexpr
operator
float
()
const
{
if
(
data
==
nan_mask
||
data
==
0
)
{
uint32_t
bits
=
data
<<
1
;
bits
|=
1
;
bits
<<=
22
;
return
bit_cast
<
float
>
(
bits
);
}
else
{
uint32_t
bits
=
data
<<
23
;
return
bit_cast
<
float
>
(
bits
);
}
}
__host__
__device__
constexpr
bool
operator
==
(
const
e8m0_bexp_t
&
other
)
const
{
// strict IEEE compliance for NaN
return
data
==
other
.
data
&&
data
!=
nan_mask
;
}
__host__
__device__
constexpr
bool
is_nan
()
const
{
return
data
==
nan_mask
;
}
};
namespace
utils
{
template
<
typename
T
>
__host__
__device__
inline
int
get_exponent_value
(
T
x
);
template
<
>
__host__
__device__
inline
int
get_exponent_value
<
e8m0_bexp_t
>
(
e8m0_bexp_t
x
)
{
return
x
.
data
;
}
}
// namespace utils
}
// namespace ck
include/ck/utility/enable_if.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
#ifndef CK_CODE_GEN_RTC
template
<
bool
B
,
typename
T
=
void
>
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
template
<
bool
B
,
typename
T
=
void
>
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
#else
template
<
bool
B
,
class
T
=
void
>
struct
enable_if
{
};
template
<
class
T
>
struct
enable_if
<
true
,
T
>
{
using
type
=
T
;
};
template
<
bool
B
,
class
T
=
void
>
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
#endif
}
// namespace ck
include/ck/utility/env.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#pragma once
#include <cstdlib>
...
...
@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
}
// namespace ck
#endif
include/ck/utility/functional.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if
constexpr
(
predicate
)
{
return
std
::
forward
<
X
>
(
x
);
return
ck
::
forward
<
X
>
(
x
);
}
else
{
return
std
::
forward
<
Y
>
(
y
);
return
ck
::
forward
<
Y
>
(
y
);
}
}
...
...
include/ck/utility/functional4.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
...
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
}
};
...
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
,
Y
&&
y
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
std
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
ck
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
}
};
...
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using
X_
=
remove_reference_t
<
X
>
;
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
));
}
// TODO: properly implement unpack that takes any number of containers
...
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using
Y_
=
remove_reference_t
<
Y
>
;
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Y
>
(
y
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Y
>
(
y
));
}
}
// namespace ck
...
...
include/ck/utility/integral_constant.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
}
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
true_type
=
bool_constant
<
true
>
;
using
false_type
=
bool_constant
<
false
>
;
}
// namespace ck
include/ck/utility/is_detected.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
namespace
ck
{
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
{
using
value_t
=
std
::
false_type
;
using
value_t
=
integral_constant
<
bool
,
false
>
;
using
type
=
Default
;
};
template
<
class
Default
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
<
Default
,
std
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
struct
detector
<
Default
,
ck
::
void_t
<
Op
<
Args
...
>>
,
Op
,
Args
...
>
{
using
value_t
=
std
::
true_type
;
using
value_t
=
integral_constant
<
bool
,
true
>
;
using
type
=
Op
<
Args
...
>
;
};
}
// namespace detail
...
...
@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
using
is_detected
=
typename
detail
::
detector
<
nonesuch
,
void
,
Op
,
Args
...
>::
value_t
;
template
<
typename
T
>
using
is_pack2_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack2_invocable
);
using
is_pack2_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack2_invocable
);
template
<
typename
T
>
using
is_pack4_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack4_invocable
);
using
is_pack4_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack4_invocable
);
template
<
typename
T
>
using
is_pack8_invocable_t
=
decltype
(
std
::
declval
<
T
&>
().
is_pack8_invocable
);
using
is_pack8_invocable_t
=
decltype
(
ck
::
declval
<
T
&>
().
is_pack8_invocable
);
}
// namespace ck
include/ck/utility/loop_scheduler.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#pragma once
...
...
@@ -25,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
}
// namespace ck
#ifndef CK_CODE_GEN_RTC
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
LoopScheduler
&
s
)
{
switch
(
s
)
...
...
@@ -35,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
}
return
os
;
}
#endif
include/ck/utility/magic_division.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -9,6 +9,10 @@
#include "type.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#define INT32_MAX 2147483647
#endif
namespace
ck
{
// magic number division
...
...
include/ck/utility/math_v2.hpp
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
#ifndef CK_CODE_GEN_RTC
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
static
inline
__host__
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
...
...
@@ -459,7 +459,7 @@ inline __host__ double expm1<double>(double x)
{
return
std
::
expm1
(
x
);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
...
...
include/ck/utility/mxf4_utils.hpp
0 → 100644
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
::
utils
{
template
<
>
__host__
__device__
inline
bool
is_nan
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
dataBytes
[[
maybe_unused
]])
{
// no need to check for data as it does not have NaN representation
return
scale
==
NumericLimits
<
e8m0_bexp_t
>::
QuietNaN
();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template
<
>
__host__
__device__
inline
bool
is_inf
<
f4_t
>
(
e8m0_bexp_t
const
scale
[[
maybe_unused
]],
f4_t
const
data
[[
maybe_unused
]])
{
// no inf representation for ocp_e2m1_mxfp4
return
false
;
}
template
<
>
__host__
__device__
inline
bool
is_zero
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
false
;
// no need to check for scale as it does not have a 0 representation
f4_t
result
=
(
data
&
0b00001111
)
&
NumericUtils
<
f4_t
>::
set_sign_mask
;
return
result
==
0b0
;
}
template
<
>
__host__
__device__
inline
float
to_float
<
f4_t
>
(
e8m0_bexp_t
const
scale
,
f4_t
const
data
)
{
if
(
is_nan
<
f4_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
if
(
is_zero
<
f4_t
>
(
scale
,
data
))
return
0.0
f
;
f4_t
prepared_data
=
data
&
0b00001111
;
int
scale_exp
=
get_exponent_value
<
e8m0_bexp_t
>
(
scale
);
return
convert_to_float
<
f4_t
>
(
prepared_data
,
scale_exp
);
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type
<
f4_t
>
(
float
value
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
{
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
}
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type
<
f4_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
template
<
>
__host__
__device__
inline
f4_t
sat_convert_to_type_sr
<
f4_t
>
(
float
value
,
uint32_t
seed
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f4_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
;
f4_t
res
=
convert_to_type_sr
<
f4_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
f4_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f4_t
>::
DataMinSubnorm
())
return
value
<
0
?
NumericUtils
<
f4_t
>::
negative_zero_mask
:
NumericUtils
<
f4_t
>::
positive_zero_mask
;
return
res
;
}
}
// namespace ck::utils
include/ck/utility/mxf6_utils.hpp
0 → 100644
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace
ck
::
utils
{
/**
* @brief Checks if an f6_t value is NaN based on the provided scale.
*
* For f6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param dataBytes The f6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template
<
>
__host__
__device__
inline
bool
is_nan
<
f6_t
>
(
e8m0_bexp_t
const
scale
,
f6_t
const
dataBytes
[[
maybe_unused
]])
{
// no need to check for data as it does not have NaN representation
return
scale
.
is_nan
();
}
/**
* @brief Checks if an bf6_t value is NaN based on the provided scale.
*
* For bf6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param dataBytes The bf6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template
<
>
__host__
__device__
inline
bool
is_nan
<
bf6_t
>
(
e8m0_bexp_t
const
scale
,
bf6_t
const
dataBytes
[[
maybe_unused
]])
{
// no need to check for data as it does not have NaN representation
return
scale
.
is_nan
();
}
/**
* @brief Checks if an f6_t value is infinite.
*
* Because f6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return Always false, as infinity is not represented in f6_t.
*/
template
<
>
__host__
__device__
inline
bool
is_inf
<
f6_t
>
(
e8m0_bexp_t
const
scale
[[
maybe_unused
]],
f6_t
const
data
[[
maybe_unused
]])
{
// no inf representation for fp6
return
false
;
}
/**
* @brief Checks if an bf6_t value is infinite.
*
* Because bf6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return Always false, as infinity is not represented in bf6_t.
*/
template
<
>
__host__
__device__
inline
bool
is_inf
<
bf6_t
>
(
e8m0_bexp_t
const
scale
[[
maybe_unused
]],
bf6_t
const
data
[[
maybe_unused
]])
{
// no inf representation for bf6
return
false
;
}
/**
* @brief Checks whether an f6_t value is zero.
*
* If the specified f6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template
<
>
__host__
__device__
inline
bool
is_zero
<
f6_t
>
(
e8m0_bexp_t
const
scale
,
f6_t
const
data
)
{
if
(
is_nan
<
f6_t
>
(
scale
,
data
))
return
false
;
// no need to check for scale as it does not have a 0 representation
f6_t
result
=
(
data
&
0b00111111
)
&
NumericUtils
<
f6_t
>::
set_sign_mask
;
return
result
==
0b0
;
}
/**
* @brief Checks whether an bf6_t value is zero.
*
* If the specified bf6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template
<
>
__host__
__device__
inline
bool
is_zero
<
bf6_t
>
(
e8m0_bexp_t
const
scale
,
bf6_t
const
data
)
{
if
(
is_nan
<
bf6_t
>
(
scale
,
data
))
return
false
;
// no need to check for scale as it does not have a 0 representation
bf6_t
result
=
(
data
&
0b00111111
)
&
NumericUtils
<
bf6_t
>::
set_sign_mask
;
return
result
==
0b0
;
}
/**
* @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the f6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to convert.
* @return The converted float value.
*/
template
<
>
__host__
__device__
inline
float
to_float
<
f6_t
>
(
e8m0_bexp_t
const
scale
,
f6_t
const
data
)
{
if
(
is_nan
<
f6_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
if
(
is_zero
<
f6_t
>
(
scale
,
data
))
return
0.0
f
;
f6_t
prepared_data
=
data
&
0b00111111
;
int
scale_exp
=
get_exponent_value
<
e8m0_bexp_t
>
(
scale
);
return
convert_to_float
<
f6_t
>
(
prepared_data
,
scale_exp
);
}
/**
* @brief Converts an bf6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the bf6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to convert.
* @return The converted float value.
*/
template
<
>
__host__
__device__
inline
float
to_float
<
bf6_t
>
(
e8m0_bexp_t
const
scale
,
bf6_t
const
data
)
{
if
(
is_nan
<
bf6_t
>
(
scale
,
data
))
return
std
::
numeric_limits
<
float
>::
quiet_NaN
();
if
(
is_zero
<
bf6_t
>
(
scale
,
data
))
return
0.0
f
;
bf6_t
prepared_data
=
data
&
0b00111111
;
int
scale_exp
=
get_exponent_value
<
e8m0_bexp_t
>
(
scale
);
return
convert_to_float
<
bf6_t
>
(
prepared_data
,
scale_exp
);
}
/**
* @brief Converts a float to f6_t with saturation.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template
<
>
__host__
__device__
inline
f6_t
sat_convert_to_type
<
f6_t
>
(
float
value
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
{
return
sign
?
NumericUtils
<
f6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f6_t
>::
data_max_positive_normal_mask
;
}
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f6_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f6_t
>::
data_max_positive_normal_mask
;
f6_t
res
=
convert_to_type
<
f6_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f6_t
>::
DataMinSubnorm
())
return
sign
?
NumericUtils
<
f6_t
>::
negative_zero_mask
:
NumericUtils
<
f6_t
>::
positive_zero_mask
;
return
res
;
}
/**
* @brief Converts a float to bf6_t with saturation.
*
* If the input is NaN or exceeds the representable range for bf6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated bf6_t value.
*/
template
<
>
__host__
__device__
inline
bf6_t
sat_convert_to_type
<
bf6_t
>
(
float
value
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
{
return
sign
?
NumericUtils
<
bf6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
bf6_t
>::
data_max_positive_normal_mask
;
}
if
(
std
::
abs
(
value
)
>
NumericLimits
<
bf6_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
bf6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
bf6_t
>::
data_max_positive_normal_mask
;
bf6_t
res
=
convert_to_type
<
bf6_t
>
(
value
);
if
(
std
::
abs
(
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
bf6_t
>::
DataMinSubnorm
())
return
sign
?
NumericUtils
<
bf6_t
>::
negative_zero_mask
:
NumericUtils
<
bf6_t
>::
positive_zero_mask
;
return
res
;
}
/**
* @brief Converts a float to f6_t with saturation and stochastic rounding.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template
<
>
__host__
__device__
inline
f6_t
sat_convert_to_type_sr
<
f6_t
>
(
float
value
,
uint32_t
seed
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
return
sign
?
NumericUtils
<
f6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f6_t
>::
data_max_positive_normal_mask
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
f6_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
f6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
f6_t
>::
data_max_positive_normal_mask
;
f6_t
res
=
convert_to_type_sr
<
f6_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
f6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
f6_t
>::
DataMinSubnorm
())
return
sign
?
NumericUtils
<
f6_t
>::
negative_zero_mask
:
NumericUtils
<
f6_t
>::
positive_zero_mask
;
return
res
;
}
/**
* @brief Converts a float to f6_t with saturation and stochastic rounding.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template
<
>
__host__
__device__
inline
bf6_t
sat_convert_to_type_sr
<
bf6_t
>
(
float
value
,
uint32_t
seed
)
{
cvt
t
;
t
.
value_float
=
value
;
uint32_t
sign
=
t
.
value_bitwise
>>
31
;
if
(
std
::
isnan
(
value
))
return
sign
?
NumericUtils
<
bf6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
bf6_t
>::
data_max_positive_normal_mask
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
bf6_t
>::
Max
())
// covers inf case as well
return
sign
?
NumericUtils
<
bf6_t
>::
data_max_negative_normal_mask
:
NumericUtils
<
bf6_t
>::
data_max_positive_normal_mask
;
bf6_t
res
=
convert_to_type_sr
<
bf6_t
>
(
value
,
seed
);
if
(
std
::
abs
(
to_float
<
bf6_t
>
(
NumericLimits
<
e8m0_bexp_t
>::
Binary_1
(),
res
))
<
NumericLimits
<
bf6_t
>::
DataMinSubnorm
())
return
sign
?
NumericUtils
<
bf6_t
>::
negative_zero_mask
:
NumericUtils
<
bf6_t
>::
positive_zero_mask
;
return
res
;
}
}
// namespace ck::utils
include/ck/utility/mxf8_utils.hpp
0 → 100644
View file @
7572a691
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace
ck
{
namespace
fp8_impl
{
#if CK_MX_FP8_CVT_FAST_PATH
template
<
ck_fp8_interpretation_t
interpret
>
static
__device__
float
cast_to_f32_from_f8_scaled
(
float
scale
,
fp8_storage_t
v
)
{
union
{
unsigned
int
i32val
;
unsigned
char
i8val
[
4
];
}
val
;
val
.
i8val
[
0
]
=
v
;
static_assert
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only OCP interpretations are supported"
);
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
return
__builtin_amdgcn_cvt_scalef32_f32_fp8
(
val
.
i32val
,
scale
,
0
);
}
else
{
return
__builtin_amdgcn_cvt_scalef32_f32_bf8
(
val
.
i32val
,
scale
,
0
);
}
}
template
<
ck_fp8_interpretation_t
interpret
>
static
__device__
float2_t
cast_to_f32x2_from_f8x2_scaled
(
float
scale
,
fp8x2_storage_t
v
)
{
const
auto
i16val
=
bit_cast
<
uint16_t
>
(
v
);
static_assert
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interpret
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only OCP interpretations are supported"
);
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
return
__builtin_amdgcn_cvt_scalef32_pk_f32_fp8
(
i16val
,
scale
,
0
);
}
else
{
return
__builtin_amdgcn_cvt_scalef32_pk_f32_bf8
(
i16val
,
scale
,
0
);
}
}
template
<
ck_fp8_interpretation_t
interpret
,
bool
stochastic_rounding
=
false
>
static
__device__
fp8_storage_t
cast_to_f8_from_f32_scaled
(
float
v
,
unsigned
int
rng
=
0
,
float
scale
=
1.0
f
)
{
fp8_storage_t
i8data
;
union
{
float
fval
;
unsigned
int
i32val
;
}
val
;
union
{
uint32_t
ival
;
vector_type
<
int16_t
,
2
>::
type
v2i16
;
fp8_storage_t
v4i8
[
4
];
}
ret
{};
// unsigned int ival = 0;
val
.
fval
=
v
;
if
constexpr
(
stochastic_rounding
)
{
ret
.
ival
=
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
?
__builtin_amdgcn_cvt_scalef32_sr_fp8_f32
(
ret
.
ival
,
val
.
fval
,
rng
,
scale
,
0
)
:
__builtin_amdgcn_cvt_scalef32_sr_bf8_f32
(
ret
.
ival
,
val
.
fval
,
rng
,
scale
,
0
);
i8data
=
ret
.
v4i8
[
0
];
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
// If fval / scale > max fp8, returns Nan
ret
.
v2i16
=
__builtin_amdgcn_cvt_scalef32_pk_fp8_f32
(
/*old_vdst*/
ret
.
v2i16
,
val
.
fval
,
val
.
fval
,
scale
,
/*dst_lo_hi_sel*/
false
);
}
else
{
// If fval / scale > max bf8, returns Inf
ret
.
v2i16
=
__builtin_amdgcn_cvt_scalef32_pk_bf8_f32
(
/*old_vdst*/
ret
.
v2i16
,
val
.
fval
,
val
.
fval
,
scale
,
/*dst_lo_hi_sel*/
false
);
}
i8data
=
ret
.
v4i8
[
0
];
}
return
i8data
;
}
template
<
ck_fp8_interpretation_t
interpret
,
bool
stochastic_rounding
=
false
>
static
__device__
fp8x2_storage_t
cast_to_f8_from_f32_scaled
(
float2_t
v
,
unsigned
int
rng
=
0
,
float
scale
=
1.0
f
)
{
union
{
uint32_t
ival
;
vector_type
<
int16_t
,
2
>::
type
v2i16
;
StaticallyIndexedArray
<
fp8x2_storage_t
,
2
>
v2f8x2
;
}
ret
{};
if
constexpr
(
stochastic_rounding
)
{
fp8x2_storage_t
f8x2
;
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
ret
.
ival
=
__builtin_amdgcn_cvt_scalef32_sr_fp8_f32
(
ret
.
ival
,
v
[
0
],
rng
,
scale
,
0
);
f8x2
[
0
]
=
ret
.
v2f8x2
(
Number
<
0
>
{})[
0
];
ret
.
ival
=
__builtin_amdgcn_cvt_scalef32_sr_fp8_f32
(
ret
.
ival
,
v
[
1
],
rng
,
scale
,
0
);
f8x2
[
1
]
=
ret
.
v2f8x2
(
Number
<
0
>
{})[
0
];
}
else
{
ret
.
ival
=
__builtin_amdgcn_cvt_scalef32_sr_bf8_f32
(
ret
.
ival
,
v
[
0
],
rng
,
scale
,
0
);
f8x2
[
0
]
=
ret
.
v2f8x2
(
Number
<
0
>
{})[
0
];
ret
.
ival
=
__builtin_amdgcn_cvt_scalef32_sr_bf8_f32
(
ret
.
ival
,
v
[
1
],
rng
,
scale
,
0
);
f8x2
[
1
]
=
ret
.
v2f8x2
(
Number
<
0
>
{})[
0
];
}
return
f8x2
;
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if
constexpr
(
interpret
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
// If fval / scale > max fp8, returns Nan
ret
.
v2i16
=
__builtin_amdgcn_cvt_scalef32_pk_fp8_f32
(
/*old_vdst*/
ret
.
v2i16
,
v
[
0
],
v
[
1
],
scale
,
/*dst_lo_hi_sel*/
false
);
}
else
{
// If fval / scale > max bf8, returns Inf
ret
.
v2i16
=
__builtin_amdgcn_cvt_scalef32_pk_bf8_f32
(
/*old_vdst*/
ret
.
v2i16
,
v
[
0
],
v
[
1
],
scale
,
/*dst_lo_hi_sel*/
false
);
}
return
ret
.
v2f8x2
(
Number
<
0
>
{});
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
#if CK_MX_FP8_CVT_FAST_PATH
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
bool
stochastic_rounding
=
false
>
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8_scaled
(
const
float
f
,
float
scale
)
{
__is_interpret_supported
(
interp
);
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
return
cast_to_f8_from_f32_scaled
<
interp
,
stochastic_rounding
>
(
f
,
rng
,
scale
);
}
/**
* \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
bool
stochastic_rounding
=
false
>
__host__
__device__
static
inline
fp8x2_storage_t
cvt_float_to_fp8_scaled
(
const
float2_t
f
,
float
scale
)
{
__is_interpret_supported
(
interp
);
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
[
0
]);
}
return
cast_to_f8_from_f32_scaled
<
interp
,
stochastic_rounding
>
(
f
,
rng
,
scale
);
}
#else
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
bool
stochastic_rounding
=
false
>
__host__
__device__
static
inline
fp8_storage_t
cvt_float_to_fp8_scaled
(
const
float
f
,
float
scale
)
{
static_assert
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only OCP interpretations are supported"
);
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
);
}
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
return
cast_to_f8
<
float
,
3
,
4
,
false
,
true
,
stochastic_rounding
>
(
f
/
scale
,
rng
);
}
else
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
)
{
return
cast_to_f8
<
float
,
2
,
5
,
false
,
true
,
stochastic_rounding
>
(
f
/
scale
,
rng
);
}
else
{
__hip_assert
(
false
&&
"FP8 type is not supported by current target device"
);
return
0
;
}
}
/**
* \brief convert two float to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
* \tparam interp interpretation of fp8
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template
<
ck_fp8_interpretation_t
interp
,
bool
stochastic_rounding
=
false
>
__host__
__device__
static
inline
fp8x2_storage_t
cvt_float_to_fp8_scaled
(
const
float2_t
f
,
float
scale
)
{
static_assert
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
||
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
,
"Only OCP interpretations are supported"
);
uint32_t
rng
=
0
;
if
constexpr
(
stochastic_rounding
)
{
constexpr
int
seed
=
1254739
;
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
f
),
f
[
0
]);
}
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E4M3_OCP
)
{
return
{
cast_to_f8
<
float
,
3
,
4
,
false
,
true
,
stochastic_rounding
>
(
f
[
0
]
/
scale
,
rng
),
cast_to_f8
<
float
,
3
,
4
,
false
,
true
,
stochastic_rounding
>
(
f
[
1
]
/
scale
,
rng
)};
}
else
if
constexpr
(
interp
==
ck_fp8_interpretation_t
::
CK_E5M2_OCP
)
{
return
{
cast_to_f8
<
float
,
2
,
5
,
false
,
true
,
stochastic_rounding
>
(
f
[
0
]
/
scale
,
rng
),
cast_to_f8
<
float
,
2
,
5
,
false
,
true
,
stochastic_rounding
>
(
f
[
1
]
/
scale
,
rng
)};
}
else
{
__hip_assert
(
false
&&
"FP8 type is not supported by current target device"
);
return
0
;
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
}
// namespace fp8_impl
// Declare a template function for fp8 conversion using SR
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
mxf8_convert_sr
(
X
x
,
float
scale
);
// Declare a template function for fp8 conversion using RNE
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
mxf8_convert_rne
(
X
x
,
float
scale
);
// convert fp32 to fp8 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8_ocp_t
mxf8_convert_rne
<
f8_ocp_t
,
float
>
(
float
x
,
float
scale
)
{
return
f8_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
f8_ocp_t
::
default_interpret
>
(
x
,
scale
)};
}
// convert fp32 to bf8 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8_ocp_t
mxf8_convert_rne
<
bf8_ocp_t
,
float
>
(
float
x
,
float
scale
)
{
return
bf8_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
bf8_ocp_t
::
default_interpret
>
(
x
,
scale
)};
}
// convert fp32x2 to fp8x2 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8x2_ocp_t
mxf8_convert_rne
<
f8x2_ocp_t
,
float2_t
>
(
float2_t
x
,
float
scale
)
{
return
f8x2_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
f8_ocp_t
::
default_interpret
>
(
x
,
scale
)};
}
// convert fp32x2 to bf8x2 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8x2_ocp_t
mxf8_convert_rne
<
bf8x2_ocp_t
,
float2_t
>
(
float2_t
x
,
float
scale
)
{
return
bf8x2_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
bf8_ocp_t
::
default_interpret
>
(
x
,
scale
)};
}
// convert fp32x16 to fp8x16 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8x16_ocp_t
mxf8_convert_rne
<
f8x16_ocp_t
,
float16_t
>
(
float16_t
x
,
float
scale
)
{
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
in
{
x
};
union
{
f8x16_ocp_t
fp8_1x16
;
f8x2_ocp_t
fp8_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
fp8_2x8
[
i
]
=
mxf8_convert_rne
<
f8x2_ocp_t
>
(
in
.
float_2x8
[
i
],
scale
);
});
return
out
.
fp8_1x16
;
}
// convert fp32x16 to bf8x16 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8x16_ocp_t
mxf8_convert_rne
<
bf8x16_ocp_t
,
float16_t
>
(
float16_t
x
,
float
scale
)
{
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
in
{
x
};
union
{
bf8x16_ocp_t
bf8_1x16
;
bf8x2_ocp_t
bf8_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
bf8_2x8
[
i
]
=
mxf8_convert_rne
<
bf8x2_ocp_t
>
(
in
.
float_2x8
[
i
],
scale
);
});
return
out
.
bf8_1x16
;
}
// convert fp32x32 to fp8x32 with rounding to nearest even
template
<
>
inline
__host__
__device__
f8x32_ocp_t
mxf8_convert_rne
<
f8x32_ocp_t
,
float32_t
>
(
float32_t
x
,
float
scale
)
{
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
in
{
x
};
union
{
f8x32_ocp_t
fp8_1x32
;
f8x16_ocp_t
fp8_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
fp8_16x2
[
i
]
=
mxf8_convert_rne
<
f8x16_ocp_t
>
(
in
.
float_16x2
[
i
],
scale
);
});
return
out
.
fp8_1x32
;
}
// convert fp32x32 to bf8x32 with rounding to nearest even
template
<
>
inline
__host__
__device__
bf8x32_ocp_t
mxf8_convert_rne
<
bf8x32_ocp_t
,
float32_t
>
(
float32_t
x
,
float
scale
)
{
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
in
{
x
};
union
{
bf8x32_ocp_t
bf8_1x32
;
bf8x16_ocp_t
bf8_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
bf8_16x2
[
i
]
=
mxf8_convert_rne
<
bf8x16_ocp_t
>
(
in
.
float_16x2
[
i
],
scale
);
});
return
out
.
bf8_1x32
;
}
// convert fp32 to fp8 with stochastic rounding
template
<
>
inline
__host__
__device__
f8_ocp_t
mxf8_convert_sr
<
f8_ocp_t
,
float
>
(
float
x
,
float
scale
)
{
return
f8_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
f8_ocp_t
::
default_interpret
,
true
>
(
x
,
scale
)};
}
// convert fp32 to bf8 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8_ocp_t
mxf8_convert_sr
<
bf8_ocp_t
,
float
>
(
float
x
,
float
scale
)
{
return
bf8_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
bf8_ocp_t
::
default_interpret
,
true
>
(
x
,
scale
)};
}
// convert fp32x2 to fp8x2 with stochastic rounding
template
<
>
inline
__host__
__device__
f8x2_ocp_t
mxf8_convert_sr
<
f8x2_ocp_t
,
float2_t
>
(
float2_t
x
,
float
scale
)
{
return
f8x2_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
f8_ocp_t
::
default_interpret
,
true
>
(
x
,
scale
)};
}
// convert fp32x2 to bf8x2 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8x2_ocp_t
mxf8_convert_sr
<
bf8x2_ocp_t
,
float2_t
>
(
float2_t
x
,
float
scale
)
{
return
bf8x2_ocp_t
{
fp8_impl
::
cvt_float_to_fp8_scaled
<
bf8_ocp_t
::
default_interpret
,
true
>
(
x
,
scale
)};
}
// convert fp32x16 to fp8x16 with stochastic rounding
template
<
>
inline
__host__
__device__
f8x16_ocp_t
mxf8_convert_sr
<
f8x16_ocp_t
,
float16_t
>
(
float16_t
x
,
float
scale
)
{
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
in
{
x
};
union
{
f8x16_ocp_t
fp8_1x16
;
f8x2_ocp_t
fp8_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
fp8_2x8
[
i
]
=
mxf8_convert_sr
<
f8x2_ocp_t
>
(
in
.
float_2x8
[
i
],
scale
);
});
return
out
.
fp8_1x16
;
}
// convert fp32x16 to bf8x16 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8x16_ocp_t
mxf8_convert_sr
<
bf8x16_ocp_t
,
float16_t
>
(
float16_t
x
,
float
scale
)
{
union
{
float16_t
float_1x16
;
float2_t
float_2x8
[
8
];
}
in
{
x
};
union
{
bf8x16_ocp_t
bf8_1x16
;
bf8x2_ocp_t
bf8_2x8
[
8
];
}
out
{};
ck
::
static_for
<
0
,
8
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
bf8_2x8
[
i
]
=
mxf8_convert_sr
<
bf8x2_ocp_t
>
(
in
.
float_2x8
[
i
],
scale
);
});
return
out
.
bf8_1x16
;
}
// convert fp32x32 to fp8x32 with stochastic rounding
template
<
>
inline
__host__
__device__
f8x32_ocp_t
mxf8_convert_sr
<
f8x32_ocp_t
,
float32_t
>
(
float32_t
x
,
float
scale
)
{
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
in
{
x
};
union
{
f8x32_ocp_t
fp8_1x32
;
f8x16_ocp_t
fp8_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
fp8_16x2
[
i
]
=
mxf8_convert_sr
<
f8x16_ocp_t
>
(
in
.
float_16x2
[
i
],
scale
);
});
return
out
.
fp8_1x32
;
}
// convert fp32x32 to bf8x32 with stochastic rounding
template
<
>
inline
__host__
__device__
bf8x32_ocp_t
mxf8_convert_sr
<
bf8x32_ocp_t
,
float32_t
>
(
float32_t
x
,
float
scale
)
{
union
{
float32_t
float_1x32
;
float16_t
float_16x2
[
2
];
}
in
{
x
};
union
{
bf8x32_ocp_t
bf8_1x32
;
bf8x16_ocp_t
bf8_16x2
[
2
];
}
out
{};
ck
::
static_for
<
0
,
2
,
1
>
{}(
[
&
](
auto
i
)
{
out
.
bf8_16x2
[
i
]
=
mxf8_convert_sr
<
bf8x16_ocp_t
>
(
in
.
float_16x2
[
i
],
scale
);
});
return
out
.
bf8_1x32
;
}
}
// namespace ck
include/ck/utility/mxfp_utils.hpp
0 → 100644
View file @
7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
::
utils
{
union
cvt
{
float
value_float
;
uint32_t
value_bitwise
;
};
template
<
typename
DTYPE
>
inline
bool
getDataHasInf
()
{
return
DTYPE
::
dataInfo
.
hasInf
;
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_zero
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_nan
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
bool
is_inf
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
inline
int
get_exponent_value
(
T
x
)
{
x
>>=
NumericUtils
<
T
>::
mant
;
x
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
return
static_cast
<
int
>
(
x
);
}
template
<
typename
T
>
__host__
__device__
inline
bool
is_subnormal
(
T
x
)
{
return
get_exponent_value
<
T
>
(
x
)
==
0
;
}
template
<
typename
T
>
__host__
__device__
inline
double
get_mantissa_value
(
T
x
)
{
double
mantissa
=
is_subnormal
<
T
>
(
x
)
?
0.0
f
:
1.0
f
;
for
(
uint
i
=
0
;
i
<
NumericUtils
<
T
>::
mant
;
i
++
)
{
mantissa
+=
std
::
pow
(
2
,
-
int32_t
((
NumericUtils
<
T
>::
mant
-
i
)))
*
(
x
&
0b1
);
x
>>=
1
;
}
return
mantissa
;
}
template
<
typename
T
>
__host__
__device__
inline
bool
get_data_has_inf
()
{
return
NumericUtils
<
T
>::
has_inf
;
}
template
<
typename
T
>
__host__
__device__
float
convert_to_float
(
T
data
,
int
scale_exp
)
{
float
d_sign
=
std
::
pow
(
-
1
,
static_cast
<
float
>
(
data
>>
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
)));
float
d_exp
;
if
(
is_subnormal
<
T
>
(
data
))
d_exp
=
std
::
pow
(
2
,
1
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
else
d_exp
=
std
::
pow
(
2
,
get_exponent_value
<
T
>
(
data
)
-
static_cast
<
int
>
(
NumericUtils
<
T
>::
bias
));
float
d_mant
=
get_mantissa_value
<
T
>
(
data
);
float
data_value
=
d_sign
*
d_exp
*
d_mant
;
float
scale_value
=
std
::
pow
(
2
,
static_cast
<
float
>
((
scale_exp
-
static_cast
<
int
>
(
NumericUtils
<
e8m0_bexp_t
>::
bias
))));
return
data_value
*
scale_value
;
}
template
<
typename
T
>
__host__
__device__
inline
float
to_float
(
e8m0_bexp_t
const
scale
,
T
const
data
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type
(
float
value
);
template
<
typename
T
>
__host__
__device__
T
sat_convert_to_type_sr
(
float
value
,
uint32_t
seed
);
template
<
typename
T
>
inline
T
convert_to_type
(
float
value
)
{
using
bitwise_type
=
typename
NumericUtils
<
T
>::
bitwise_type
;
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint32_t
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
bitwise_type
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
bitwise_type
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
bitwise_type
mantissa
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
{
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
}
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
const
int
mfmt
=
NumericUtils
<
float
>::
mant
;
uint32_t
x
;
x
=
bit_cast
<
uint32_t
>
(
value
);
uint32_t
head
,
mantissa
;
int32_t
exponent
,
bias
;
uint32_t
sign
;
head
=
x
&
NumericUtils
<
float
>::
head_mask
;
mantissa
=
x
&
NumericUtils
<
float
>::
mant_mask
;
exponent
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
sign
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
bias
=
NumericUtils
<
float
>::
bias
;
if
(
x
==
0
)
{
return
0b0
;
}
const
int
mini_bias
=
NumericUtils
<
T
>::
bias
;
const
int
mini_denormal_act_exponent
=
1
-
mini_bias
;
int
act_exponent
,
out_exponent
,
exponent_diff
;
bool
is_subnorm
=
false
;
if
(
exponent
==
0
)
{
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
mini_denormal_act_exponent
)
{
exponent_diff
=
mini_denormal_act_exponent
-
act_exponent
;
is_subnorm
=
true
;
}
else
{
exponent_diff
=
0
;
}
mantissa
+=
(
1UL
<<
mfmt
);
}
auto
shift_amount
=
(
mfmt
-
NumericUtils
<
T
>::
mant
+
exponent_diff
);
shift_amount
=
(
shift_amount
>=
64
)
?
63
:
shift_amount
;
bool
midpoint
=
(
mantissa
&
((
1UL
<<
shift_amount
)
-
1
))
==
(
1UL
<<
(
shift_amount
-
1
));
float
min_subnorm
=
NumericLimits
<
T
>::
DataMinSubnorm
()
*
(
sign
?
-
1
:
1
);
if
(
is_subnorm
&&
std
::
abs
(
value
)
<
std
::
abs
(
min_subnorm
))
{
// closer to 0
if
(
std
::
abs
(
value
)
<=
std
::
abs
(
min_subnorm
-
value
))
return
0
;
else
return
1
|
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
));
}
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
out_exponent
=
(
act_exponent
+
exponent_diff
)
+
mini_bias
-
(
implicit_one
?
0
:
1
);
uint32_t
drop_mask
=
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
))
-
1
;
bool
odd
=
mantissa
&
(
1UL
<<
(
mfmt
-
NumericUtils
<
T
>::
mant
));
mantissa
+=
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
)
&
drop_mask
;
if
(
out_exponent
==
0
)
{
if
((
1UL
<<
mfmt
)
&
mantissa
)
{
out_exponent
=
1
;
}
}
else
{
if
((
1UL
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
out_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
NumericUtils
<
T
>::
mant
);
if
(
out_exponent
==
0
&&
mantissa
==
0
)
{
return
0
;
}
mantissa
&=
(
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
;
return
(
sign
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
|
(
out_exponent
<<
NumericUtils
<
T
>::
mant
)
|
mantissa
;
}
template
<
typename
T
>
inline
T
convert_to_type_sr
(
float
value
,
uint32_t
seed
)
{
if
(
std
::
abs
(
value
)
>
NumericLimits
<
T
>::
Max
())
{
float
max_value
=
NumericLimits
<
T
>::
Max
();
cvt
t
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
max_value
;
uint
max_bitwise
=
t
.
value_bitwise
;
// cppcheck-suppress redundantAssignment
t
.
value_float
=
value
;
T
sign
=
t
.
value_bitwise
>>
(
NumericUtils
<
float
>::
exp
+
NumericUtils
<
float
>::
mant
);
T
exp
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
)
-
(
NumericUtils
<
float
>::
bias
-
NumericUtils
<
T
>::
bias
);
uint32_t
mant_prev
=
max_bitwise
>>
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant_prev
&=
((
1UL
<<
NumericUtils
<
T
>::
mant
)
-
1
);
mant_prev
--
;
mant_prev
<<=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
uint32_t
prev_bit
=
((
max_bitwise
>>
NumericUtils
<
float
>::
mant
)
<<
NumericUtils
<
float
>::
mant
)
|
mant_prev
;
t
.
value_bitwise
=
prev_bit
;
float
prev_val
=
t
.
value_float
;
float
diff
=
max_value
-
prev_val
;
float
actual_max
=
max_value
+
(
diff
/
2
);
if
(
std
::
abs
(
value
)
<
actual_max
)
{
double
d_max_value
=
static_cast
<
double
>
(
max_value
);
double
d_actual_max
=
static_cast
<
double
>
(
actual_max
);
double
d_value
=
static_cast
<
double
>
(
value
);
double
d_is
=
std
::
abs
(
d_max_value
-
d_actual_max
);
double
d_seed
=
static_cast
<
double
>
(
seed
);
double
d_prob
=
1.0
f
-
(
std
::
abs
(
d_value
-
d_max_value
)
/
d_is
);
// prob to round down
double
thresh
=
UINT_MAX
*
d_prob
;
if
(
!
get_data_has_inf
<
T
>
()
||
d_seed
<=
thresh
)
// return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
return
sign
==
0
?
NumericUtils
<
f4_t
>::
data_max_positive_normal_mask
:
NumericUtils
<
f4_t
>::
data_max_negative_normal_mask
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
else
{
if
(
!
get_data_has_inf
<
T
>
())
return
(
1
<<
(
NumericUtils
<
T
>::
mant
+
NumericUtils
<
T
>::
exp
))
-
1
;
else
{
exp
++
;
return
sign
<<
((
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
))
// inf
|
(
exp
<<
NumericUtils
<
T
>::
mant
);
}
}
}
uint32_t
f32
=
bit_cast
<
uint32_t
>
(
value
);
auto
f32_mant
=
f32
&
NumericUtils
<
float
>::
mant_mask
;
auto
head
=
f32
&
NumericUtils
<
float
>::
head_mask
;
auto
f32_exp
=
(
head
>>
NumericUtils
<
float
>::
mant
)
&
NumericUtils
<
float
>::
exp_mask
;
auto
sign_bit
=
head
>>
(
NumericUtils
<
float
>::
mant
+
NumericUtils
<
float
>::
exp
);
auto
sign
=
sign_bit
<<
(
NumericUtils
<
T
>::
exp
+
NumericUtils
<
T
>::
mant
);
f32_exp
=
static_cast
<
int32_t
>
(
f32_exp
)
-
NumericUtils
<
float
>::
bias
;
int32_t
exp
=
f32_exp
;
auto
mant
=
f32_mant
;
bool
subnorm
=
false
;
if
(
f32
==
0
)
return
0b0
;
if
(
exp
>=
NumericUtils
<
T
>::
unbiased_exp_min
)
{
mant
=
f32_mant
;
}
// if the exponent bit is 8, then the subnormal is exactly the same as f32
else
if
(
exp
<
NumericUtils
<
T
>::
unbiased_exp_min
&&
NumericUtils
<
T
>::
exp
<
NumericUtils
<
float
>::
exp
)
{
subnorm
=
true
;
auto
diff
=
static_cast
<
uint32_t
>
(
NumericUtils
<
T
>::
unbiased_exp_min
-
exp
);
if
(
diff
>=
32
)
{
mant
=
0
;
f32_mant
=
0
;
}
else
{
f32_mant
|=
static_cast
<
uint32_t
>
(
1
)
<<
NumericUtils
<
float
>::
mant
;
f32_mant
>>=
diff
;
}
exp
=
0
;
mant
=
f32_mant
;
}
uint32_t
sr_shift
=
NumericUtils
<
T
>::
sr_shift
;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant
+=
seed
>>
sr_shift
;
// Increment exponent when mantissa overflows due to rounding
if
(
mant
>=
static_cast
<
uint32_t
>
(
1
)
<<
NumericUtils
<
float
>::
mant
)
++
exp
;
mant
>>=
(
NumericUtils
<
float
>::
mant
-
NumericUtils
<
T
>::
mant
);
mant
&=
((
1
<<
NumericUtils
<
T
>::
mant
)
-
1
);
auto
biased_exp
=
static_cast
<
uint32_t
>
(
exp
);
if
(
!
subnorm
)
biased_exp
=
static_cast
<
uint32_t
>
(
exp
+
NumericUtils
<
T
>::
bias
);
biased_exp
&=
((
1
<<
NumericUtils
<
T
>::
exp
)
-
1
);
auto
val
=
sign
|
biased_exp
<<
NumericUtils
<
T
>::
mant
|
mant
;
return
val
;
}
}
// namespace ck::utils
Prev
1
…
12
13
14
15
16
17
18
19
20
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment