Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
aa1920da
Commit
aa1920da
authored
Nov 08, 2024
by
Rostyslav Geyyer
Browse files
Add fp4 vectors
parent
9433306a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
186 additions
and
56 deletions
+186
-56
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+186
-56
No files found.
include/ck/utility/data_type.hpp
View file @
aa1920da
...
@@ -1191,62 +1191,6 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1191,62 +1191,6 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
}
}
};
};
template
<
>
struct
NumericLimits
<
f4_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x2
;
// 0b0010
static
constexpr
uint8_t
binary_max_normal
=
0x7
;
// 0b0111
static
constexpr
uint8_t
binary_lowest_normal
=
0xF
;
// 0b1111
static
constexpr
uint8_t
binary_min_subnorm
=
0x1
;
// 0b0001
static
constexpr
uint8_t
binary_max_subnorm
=
0x1
;
// 0b0001
static
constexpr
float
data_max_normal_number
=
6
;
static
constexpr
float
data_min_subnormal_number
=
0.5
;
__host__
__device__
static
constexpr
f4_t
Min
()
{
return
f4_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
f4_t
Max
()
{
return
f4_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
f4_t
Lowest
()
{
return
f4_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
f4_t
MinSubnorm
()
{
return
f4_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
f4_t
MaxSubnorm
()
{
return
f4_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
e8m0_scale_t
>
{
static
constexpr
e8m0_scale_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_scale_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_scale_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_scale_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_scale_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_scale_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_scale_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_scale_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_scale_t
Min
()
{
return
e8m0_scale_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Max
()
{
return
e8m0_scale_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
QuietNaN
()
{
return
e8m0_scale_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_1
()
{
return
e8m0_scale_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_2
()
{
return
e8m0_scale_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_3
()
{
return
e8m0_scale_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_135
()
{
return
e8m0_scale_t
(
binary_135
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_142
()
{
return
e8m0_scale_t
(
binary_142
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
8
,
typename
std
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
{
{
...
@@ -1643,6 +1587,136 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
...
@@ -1643,6 +1587,136 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
}
}
};
};
template
<
>
struct
vector_type
<
f4_t
,
2
,
typename
std
::
enable_if_t
<!
is_native_type
<
f4_t
>
()
>>
{
using
d1_t
=
f4_t
;
using
d2_t
=
uint8_t
;
using
type
=
d2_t
;
union
alignas
(
next_pow2
(
sizeof
(
type
)))
{
d2_t
d2_
;
StaticallyIndexedArray
<
d1_t
,
2
>
d1x2_
;
StaticallyIndexedArray
<
d2_t
,
1
>
d2x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x1_
;
}
else
{
return
err
;
}
}
};
template
<
>
struct
vector_type
<
f4_t
,
4
,
typename
std
::
enable_if_t
<!
is_native_type
<
f4_t
>
()
>>
{
using
d1_t
=
f4_t
;
using
d2_t
=
uint8_t
;
using
d4_t
=
uint16_t
;
using
type
=
d4_t
;
union
alignas
(
next_pow2
(
sizeof
(
type
)))
{
d4_t
d4_
;
StaticallyIndexedArray
<
d1_t
,
4
>
d1x4_
;
StaticallyIndexedArray
<
d2_t
,
2
>
d2x2_
;
StaticallyIndexedArray
<
d4_t
,
1
>
d4x1_
;
}
data_
;
__host__
__device__
constexpr
vector_type
()
:
data_
{
type
{}}
{}
__host__
__device__
constexpr
vector_type
(
type
v
)
:
data_
{
v
}
{}
template
<
typename
X
>
__host__
__device__
constexpr
const
auto
&
AsType
()
const
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
template
<
typename
X
>
__host__
__device__
constexpr
auto
&
AsType
()
{
static_assert
(
is_same
<
X
,
d1_t
>::
value
||
is_same
<
X
,
d2_t
>::
value
||
is_same
<
X
,
d4_t
>::
value
,
"Something went wrong, please check src and dst types."
);
if
constexpr
(
is_same
<
X
,
d1_t
>::
value
)
{
return
data_
.
d1x4_
;
}
else
if
constexpr
(
is_same
<
X
,
d2_t
>::
value
)
{
return
data_
.
d2x2_
;
}
else
if
constexpr
(
is_same
<
X
,
d4_t
>::
value
)
{
return
data_
.
d4x1_
;
}
else
{
return
err
;
}
}
};
using
int64_t
=
long
;
using
int64_t
=
long
;
// fp64
// fp64
...
@@ -1805,6 +1879,62 @@ struct NumericLimits<bf8_t>
...
@@ -1805,6 +1879,62 @@ struct NumericLimits<bf8_t>
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
bf8_t
QuietNaN
()
{
return
bf8_t
(
binary_qnan
);
}
};
};
template
<
>
struct
NumericLimits
<
f4_t
>
{
static
constexpr
uint8_t
binary_min_normal
=
0x2
;
// 0b0010
static
constexpr
uint8_t
binary_max_normal
=
0x7
;
// 0b0111
static
constexpr
uint8_t
binary_lowest_normal
=
0xF
;
// 0b1111
static
constexpr
uint8_t
binary_min_subnorm
=
0x1
;
// 0b0001
static
constexpr
uint8_t
binary_max_subnorm
=
0x1
;
// 0b0001
static
constexpr
float
data_max_normal_number
=
6
;
static
constexpr
float
data_min_subnormal_number
=
0.5
;
__host__
__device__
static
constexpr
f4_t
Min
()
{
return
f4_t
(
binary_min_normal
);
}
__host__
__device__
static
constexpr
f4_t
Max
()
{
return
f4_t
(
binary_max_normal
);
}
__host__
__device__
static
constexpr
f4_t
Lowest
()
{
return
f4_t
(
binary_lowest_normal
);
}
__host__
__device__
static
constexpr
f4_t
MinSubnorm
()
{
return
f4_t
(
binary_min_subnorm
);
}
__host__
__device__
static
constexpr
f4_t
MaxSubnorm
()
{
return
f4_t
(
binary_max_subnorm
);
}
__host__
__device__
static
constexpr
float
DataMaxNorm
()
{
return
data_max_normal_number
;
}
__host__
__device__
static
constexpr
float
DataMinSubnorm
()
{
return
data_min_subnormal_number
;
}
};
template
<
>
struct
NumericLimits
<
e8m0_scale_t
>
{
static
constexpr
e8m0_scale_t
binary_min
=
0x00
;
// 0b00000000
static
constexpr
e8m0_scale_t
binary_max
=
0xFE
;
// 0b11111110
static
constexpr
e8m0_scale_t
binary_qnan
=
0xFF
;
// 0b11111111
static
constexpr
e8m0_scale_t
binary_1
=
0x7F
;
// 0b01111111
static
constexpr
e8m0_scale_t
binary_2
=
0x80
;
// 0b10000000
static
constexpr
e8m0_scale_t
binary_3
=
0x82
;
// 0b10000010
static
constexpr
e8m0_scale_t
binary_135
=
0x87
;
// 0b10000111
static
constexpr
e8m0_scale_t
binary_142
=
0x8E
;
// 0b10001110
__host__
__device__
static
constexpr
e8m0_scale_t
Min
()
{
return
e8m0_scale_t
(
binary_min
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Max
()
{
return
e8m0_scale_t
(
binary_max
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
QuietNaN
()
{
return
e8m0_scale_t
(
binary_qnan
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_1
()
{
return
e8m0_scale_t
(
binary_1
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_2
()
{
return
e8m0_scale_t
(
binary_2
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_3
()
{
return
e8m0_scale_t
(
binary_3
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_135
()
{
return
e8m0_scale_t
(
binary_135
);
}
__host__
__device__
static
constexpr
e8m0_scale_t
Binary_142
()
{
return
e8m0_scale_t
(
binary_142
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
NumericUtils
struct
NumericUtils
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment