gaoqiong / composable_kernel · Commit 7c7bd091

Authored Apr 14, 2023 by Rosty Geyyer

Refactor TypeConvert as a struct

Parent: 4c6c750a
Showing 2 changed files with 270 additions and 108 deletions (+270 −108)
include/ck/utility/data_type.hpp (+269 −106)
library/include/ck/library/utility/host_tensor.hpp (+1 −2)
include/ck/utility/data_type.hpp
...
@@ -942,33 +942,21 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
template <typename Y, typename X, typename... config>
struct TypeConvert
{
    __host__ __device__ Y operator()(X& x) const
    {
        static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
        return static_cast<Y>(x);
    }
};

// convert bfp16 to fp32
template <>
struct TypeConvert<float, bhalf_t>
{
    __host__ __device__ float operator()(bhalf_t& x) const
    {
        union
        {
...
@@ -978,125 +966,300 @@ class TypeConvert
        return u.fp32;
    }
};

// convert fp32 to bfp16
template <>
struct TypeConvert<bhalf_t, float, integral_constant<bool, true>>
{
    __host__ __device__ bhalf_t operator()(float& x) const
    {
        union
        {
            float fp32;
            uint32_t int32;
        } u = {x};

        return uint16_t(u.int32 >> 16);
    }
};

// convert fp32 to bfp16
template <>
struct TypeConvert<bhalf_t, float, integral_constant<bool, false>>
{
    __host__ __device__ bhalf_t operator()(float& x) const
    {
        union
        {
            float fp32;
            uint32_t int32;
        } u = {x};

        // When the exponent bits are not all 1s, then the value is zero, normal,
        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
        // least significant bits of the float mantissa are greater than 0x8000,
        // or if they are equal to 0x8000 and the least significant bit of the
        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
        // has the value 0x7f, then incrementing it causes it to become 0x00 and
        // the exponent is incremented by one, which is the next higher FP value
        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
        // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
        // incrementing it causes it to become an exponent of 0xFF and a mantissa
        // of 0x00, which is Inf, the next higher value to the unrounded value.
        bool flag0 = ~u.int32 & 0x7f800000;

        // When all of the exponent bits are 1, the value is Inf or NaN.
        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
        // bit being 1. Signaling NaN is indicated by the most significant
        // mantissa bit being 0 but some other bit(s) being 1. If any of the
        // lower 16 bits of the mantissa are 1, we set the least significant bit
        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
        // the bfloat16's mantissa bits are all 0.
        bool flag1 = !flag0 && (u.int32 & 0xffff);

        u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
        u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN

        return uint16_t(u.int32 >> 16);
    }
};

// convert bfp16 to fp16 via fp32
template <>
struct TypeConvert<half_t, bhalf_t>
{
    __host__ __device__ half_t operator()(bhalf_t& x) const
    {
        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
        return static_cast<half_t>(x_fp32);
    }
};

// convert fp16 to bfp16 via fp32
template <>
struct TypeConvert<bhalf_t, half_t>
{
    __host__ __device__ bhalf_t operator()(half_t& x) const
    {
        float x_fp32 = static_cast<float>(x);
        return TypeConvert<bhalf_t, float>{}(x_fp32);
    }
};

// convert bfp16 to int32 via fp32
template <>
struct TypeConvert<int32_t, bhalf_t>
{
    __host__ __device__ int32_t operator()(bhalf_t& x) const
    {
        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
        return static_cast<int32_t>(x_fp32);
    }
};

// convert int32 to bfp16 via fp32
template <>
struct TypeConvert<bhalf_t, int32_t>
{
    __host__ __device__ bhalf_t operator()(int32_t& x) const
    {
        float x_fp32 = static_cast<float>(x);
        return TypeConvert<bhalf_t, float>{}(x_fp32);
    }
};

// convert bfp16 to int8 via fp32
template <>
struct TypeConvert<int8_t, bhalf_t>
{
    __host__ __device__ int8_t operator()(bhalf_t& x) const
    {
        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
        return static_cast<int8_t>(x_fp32);
    }
};

// convert int8 to bfp16 via fp32
template <>
struct TypeConvert<bhalf_t, int8_t>
{
    __host__ __device__ bhalf_t operator()(int8_t& x) const
    {
        float x_fp32 = static_cast<float>(x);
        return TypeConvert<bhalf_t, float>{}(x_fp32);
    }
};
// class TypeConvert
// {
// public:
// // constructor
// __host__ __device__ TypeConvert()
// {
// BF16ConvertRTN_ = false; // use round to zero by default
// }
// // switch bf16 conversion mode to rtn
// __host__ __device__ void SetBF16ConvertRTN() { BF16ConvertRTN_ = true; }
// // switch bf16 conversion mode to rtz
// __host__ __device__ void SetBF16ConvertRTZ() { BF16ConvertRTN_ = false; }
// // convert for all types except bf16
// template <typename Y, typename X>
// __host__ __device__ constexpr Y convert(X x)
// {
// static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
// return static_cast<Y>(x);
// }
// // convert bfp16 to fp32
// template <>
// inline __host__ __device__ constexpr float convert<float, bhalf_t>(bhalf_t x)
// {
// union
// {
// uint32_t int32;
// float fp32;
// } u = {uint32_t(x) << 16};
// return u.fp32;
// }
// // convert fp32 to bfp16
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, float>(float x)
// {
// // if using rtn
// if(BF16ConvertRTN_)
// {
// union
// {
// float fp32;
// uint32_t int32;
// } u = {x};
// // When the exponent bits are not all 1s, then the value is zero, normal,
// // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// // This causes the bfloat16's mantissa to be incremented by 1 if the 16
// // least significant bits of the float mantissa are greater than 0x8000,
// // or if they are equal to 0x8000 and the least significant bit of the
// // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// // has the value 0x7f, then incrementing it causes it to become 0x00 and
// // the exponent is incremented by one, which is the next higher FP value
// // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// // incrementing it causes it to become an exponent of 0xFF and a mantissa
// // of 0x00, which is Inf, the next higher value to the unrounded value.
// bool flag0 = ~u.int32 & 0x7f800000;
// // When all of the exponent bits are 1, the value is Inf or NaN.
// // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// // mantissa bit. Quiet NaN is indicated by the most significant mantissa
// // bit being 1. Signaling NaN is indicated by the most significant
// // mantissa bit being 0 but some other bit(s) being 1. If any of the
// // lower 16 bits of the mantissa are 1, we set the least significant bit
// // of the bfloat16 mantissa, in order to preserve signaling NaN in case
// // the bfloat16's mantissa bits are all 0.
// bool flag1 = !flag0 && (u.int32 & 0xffff);
// u.int32 +=
// flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
// u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
// return uint16_t(u.int32 >> 16);
// }
// // if using rtz
// else
// {
// union
// {
// float fp32;
// uint32_t int32;
// } u = {x};
// return uint16_t(u.int32 >> 16);
// }
// }
// // convert bfp16 to fp16 via fp32
// template <>
// inline __host__ __device__ constexpr half_t convert<half_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<half_t>(x_fp32);
// }
// // convert fp16 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, half_t>(half_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// // convert bfp16 to int32 via fp32
// template <>
// inline __host__ __device__ constexpr int32_t convert<int32_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<int32_t>(x_fp32);
// }
// // convert int32 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int32_t>(int32_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// // convert bfp16 to int8 via fp32
// template <>
// inline __host__ __device__ constexpr int8_t convert<int8_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<int8_t>(x_fp32);
// }
// // convert int8 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int8_t>(int8_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// private:
// bool BF16ConvertRTN_;
// };
template <typename T>
struct NumericLimits
{
...
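For reference, the two fp32-to-bf16 specializations above differ only in rounding: the integral_constant<bool, true> form truncates (round toward zero), while the integral_constant<bool, false> form rounds to nearest-even and preserves signaling NaN. Below is a minimal host-only sketch of those two paths, with a plain uint16_t standing in for ck::bhalf_t so it builds without the CK headers; it is illustrative only and not part of this commit.

// Host-only sketch: same bit manipulation as the two TypeConvert<bhalf_t, float, ...>
// specializations above, using uint16_t in place of ck::bhalf_t.
#include <cstdint>
#include <cstdio>

static uint16_t float_to_bf16_rtne(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};

    bool flag0 = ~u.int32 & 0x7f800000;        // exponent bits not all 1s -> finite value
    bool flag1 = !flag0 && (u.int32 & 0xffff); // Inf/NaN with nonzero low mantissa bits

    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // round to nearest, ties to even
    u.int32 |= flag1 ? 0x10000 : 0x0;                      // preserve signaling NaN
    return uint16_t(u.int32 >> 16);
}

static uint16_t float_to_bf16_rtz(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {x};
    return uint16_t(u.int32 >> 16); // plain truncation of the low 16 mantissa bits
}

int main()
{
    float halfway = 1.0f + 0.00390625f;         // 1 + 2^-8: exactly halfway between two bf16 values
    float above   = 1.0f + 3.0f * 0.001953125f; // 1 + 3 * 2^-9: past the halfway point

    std::printf("rtne(halfway) = 0x%04x, rtz(halfway) = 0x%04x\n",
                (unsigned)float_to_bf16_rtne(halfway), (unsigned)float_to_bf16_rtz(halfway));
    std::printf("rtne(above)   = 0x%04x, rtz(above)   = 0x%04x\n",
                (unsigned)float_to_bf16_rtne(above), (unsigned)float_to_bf16_rtz(above));
    return 0;
}

On the halfway case (1 + 2^-8) both paths yield 0x3f80 because ties round to the even mantissa; on 1 + 3·2^-9 the round-to-nearest path yields 0x3f81 while truncation stays at 0x3f80.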
library/include/ck/library/utility/host_tensor.hpp
...
@@ -271,8 +271,7 @@ struct Tensor
        Tensor<OutT> ret(mDesc);

        ck::ranges::transform(mData, ret.mData.begin(), [](auto value) {
-           ck::TypeConvert type_convert = ck::TypeConvert();
-           return type_convert.convert<OutT>(value);
+           return ck::TypeConvert<OutT, ck::remove_cvref_t<decltype(value)>>{}(value);
        });

        return ret;
...
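This hunk replaces the default-constructed ck::TypeConvert object and its convert<OutT>() member call with a direct call to the TypeConvert functor selected by the output type and the decayed element type. The following self-contained sketch shows the same call pattern, with std::transform and std::remove_cvref_t standing in for ck::ranges::transform and ck::remove_cvref_t (assumes C++20; illustrative only, not library code).

// Standalone sketch of the new call pattern in the lambda body.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

// Minimal stand-in for the primary TypeConvert template from this commit.
template <typename Y, typename X, typename... config>
struct TypeConvert
{
    Y operator()(X& x) const { return static_cast<Y>(x); }
};

int main()
{
    std::vector<float> src = {1.5f, -2.25f, 3.0f};
    std::vector<int32_t> dst(src.size());

    std::transform(src.begin(), src.end(), dst.begin(), [](auto value) {
        // Same shape as the new host_tensor.hpp lambda: pick the functor by
        // <OutT, decayed input type> and invoke it on the element.
        return TypeConvert<int32_t, std::remove_cvref_t<decltype(value)>>{}(value);
    });

    for(int32_t v : dst)
        std::printf("%d ", v); // prints: 1 -2 3
    std::printf("\n");
    return 0;
}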