Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
71a7ac8b
Commit
71a7ac8b
authored
Apr 17, 2023
by
Rosty Geyyer
Browse files
Get back to template functions type_convert
parent
7c7bd091
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
284 deletions
+83
-284
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+77
-275
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+5
-8
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+1
-1
No files found.
include/ck/utility/data_type.hpp
View file @
71a7ac8b
...
@@ -942,323 +942,125 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
...
@@ -942,323 +942,125 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
template
<
typename
Y
,
typename
X
,
typename
...
config
>
// Convert X to Y
struct
TypeConvert
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
template
<
typename
X
>
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
__host__
__device__
Y
operator
()(
X
&
x
)
const
{
return
static_cast
<
Y
>
(
x
);
}
};
// convert bfp16 to fp32
return
static_cast
<
Y
>
(
x
);
template
<
>
}
// convert bfp16 to fp32
//
// bhalf_t stores the upper 16 bits of an IEEE-754 binary32. Shifting them
// back into the high half of a 32-bit word and reinterpreting as float
// reproduces the value exactly (the discarded low mantissa bits are zero).
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
    union
    {
        uint32_t int32;
        float fp32;
    } u = {uint32_t(x) << 16};

    return u.fp32;
}
// convert fp32 to bfp16
// convert fp32 to bfp16
template
<
>
template
<
>
struct
T
ype
C
onvert
<
bhalf_t
,
float
,
integral_constant
<
bool
,
false
>>
inline
__host__
__device__
constexpr
bhalf_t
t
ype
_c
onvert
<
bhalf_t
,
float
>
(
float
x
)
{
{
__host__
__device__
bhalf_t
operator
()(
float
&
x
)
const
union
{
{
union
float
fp32
;
{
uint32_t
int32
;
float
fp32
;
}
u
=
{
x
};
uint32_t
int32
;
}
u
=
{
x
};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// When the exponent bits are not all 1s, then the value is zero, normal,
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// least significant bits of the float mantissa are greater than 0x8000,
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// or if they are equal to 0x8000 and the least significant bit of the
// least significant bits of the float mantissa are greater than 0x8000,
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// or if they are equal to 0x8000 and the least significant bit of the
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// the exponent is incremented by one, which is the next higher FP value
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// the exponent is incremented by one, which is the next higher FP value
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// of 0x00, which is Inf, the next higher value to the unrounded value.
// incrementing it causes it to become an exponent of 0xFF and a mantissa
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool
flag0
=
~
u
.
int32
&
0x7f800000
;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// When all of the exponent bits are 1, the value is Inf or NaN.
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// bit being 1. Signaling NaN is indicated by the most significant
// lower 16 bits of the mantissa are 1, we set the least significant bit
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// lower 16 bits of the mantissa are 1, we set the least significant bit
// the bfloat16's mantissa bits are all 0.
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
return
uint16_t
(
u
.
int32
>>
16
);
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
}
return
uint16_t
(
u
.
int32
>>
16
);
}
};
// convert bfp16 to fp16 via fp32
// convert bfp16 to fp16 via fp32
template
<
>
template
<
>
struc
t
T
ype
C
onvert
<
half_t
,
bhalf_t
>
inline
__host__
__device__
constexpr
half_
t
t
ype
_c
onvert
<
half_t
,
bhalf_t
>
(
bhalf_t
x
)
{
{
__host__
__device__
half_t
operator
()(
bhalf_t
&
x
)
const
float
x_fp32
=
type_convert
<
float
>
(
x
);
{
float
x_fp32
=
TypeConvert
<
float
,
bhalf_t
>
{}(
x
);
return
static_cast
<
half_t
>
(
x_fp32
);
return
static_cast
<
half_t
>
(
x_fp32
);
}
}
};
// convert fp16 to bfp16 via fp32
// convert fp16 to bfp16 via fp32
template
<
>
template
<
>
struc
t
T
ype
C
onvert
<
bhalf_t
,
half_t
>
inline
__host__
__device__
constexpr
bhalf_
t
t
ype
_c
onvert
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
{
__host__
__device__
bhalf_t
operator
()(
half_t
&
x
)
const
float
x_fp32
=
static_cast
<
float
>
(
x
);
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
TypeConvert
<
bhalf_t
,
float
>
{}(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
};
// convert bfp16 to int32 via fp32
// convert bfp16 to int32 via fp32
template
<
>
template
<
>
struct
T
ype
C
onvert
<
int32_t
,
bhalf_t
>
inline
__host__
__device__
constexpr
int32_t
t
ype
_c
onvert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
{
{
__host__
__device__
int32_t
operator
()(
bhalf_t
&
x
)
const
float
x_fp32
=
type_convert
<
float
>
(
x
);
{
float
x_fp32
=
TypeConvert
<
float
,
bhalf_t
>
{}(
x
);
return
static_cast
<
int32_t
>
(
x_fp32
);
return
static_cast
<
int32_t
>
(
x_fp32
);
}
}
};
// convert int32 to bfp16 via fp32
// convert int32 to bfp16 via fp32
template
<
>
template
<
>
struct
T
ype
C
onvert
<
bhalf_t
,
int32_t
>
inline
__host__
__device__
constexpr
bhalf_t
t
ype
_c
onvert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
{
{
__host__
__device__
bhalf_t
operator
()(
int32_t
&
x
)
const
float
x_fp32
=
static_cast
<
float
>
(
x
);
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
TypeConvert
<
bhalf_t
,
float
>
{}(
x_fp32
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
}
};
// convert bfp16 to int8 via fp32
// convert bfp16 to int8 via fp32
template
<
>
template
<
>
struc
t
T
ype
C
onvert
<
int8_t
,
bhalf_t
>
inline
__host__
__device__
constexpr
int8_
t
t
ype
_c
onvert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
{
{
__host__
__device__
int8_t
operator
()(
bhalf_t
&
x
)
const
float
x_fp32
=
type_convert
<
float
>
(
x
);
{
float
x_fp32
=
TypeConvert
<
float
,
bhalf_t
>
{}(
x
);
return
static_cast
<
int8_t
>
(
x_fp32
);
return
static_cast
<
int8_t
>
(
x_fp32
);
}
}
};
// convert int8 to bfp16 via fp32
// convert int8 to bfp16 via fp32
template
<
>
template
<
>
struc
t
T
ype
C
onvert
<
bhalf_t
,
int8_t
>
inline
__host__
__device__
constexpr
bhalf_
t
t
ype
_c
onvert
<
bhalf_t
,
int8_t
>
(
int8_t
x
)
{
{
__host__
__device__
bhalf_t
operator
()(
int8_t
&
x
)
const
float
x_fp32
=
static_cast
<
float
>
(
x
);
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
TypeConvert
<
bhalf_t
,
float
>
{}(
x_fp32
);
}
};
// class TypeConvert
return
type_convert
<
bhalf_t
>
(
x_fp32
);
// {
}
// public:
// // constructor
// __host__ __device__ TypeConvert()
// {
// BF16ConvertRTN_ = false; // use round to zero by default
// }
// // switch bf16 conversion mode to rtn
// __host__ __device__ void SetBF16ConvertRTN() { BF16ConvertRTN_ = true; }
// // switch bf16 conversion mode to rtz
// __host__ __device__ void SetBF16ConvertRTZ() { BF16ConvertRTN_ = false; }
// // convert for all types except bf16
// template <typename Y, typename X>
// __host__ __device__ constexpr Y convert(X x)
// {
// static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
// return static_cast<Y>(x);
// }
// // convert bfp16 to fp32
// template <>
// inline __host__ __device__ constexpr float convert<float, bhalf_t>(bhalf_t x)
// {
// union
// {
// uint32_t int32;
// float fp32;
// } u = {uint32_t(x) << 16};
// return u.fp32;
// }
// // convert fp32 to bfp16
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, float>(float x)
// {
// // if using rtn
// if(BF16ConvertRTN_)
// {
// union
// {
// float fp32;
// uint32_t int32;
// } u = {x};
// // When the exponent bits are not all 1s, then the value is zero, normal,
// // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// // This causes the bfloat16's mantissa to be incremented by 1 if the 16
// // least significant bits of the float mantissa are greater than 0x8000,
// // or if they are equal to 0x8000 and the least significant bit of the
// // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// // has the value 0x7f, then incrementing it causes it to become 0x00 and
// // the exponent is incremented by one, which is the next higher FP value
// // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// // incrementing it causes it to become an exponent of 0xFF and a mantissa
// // of 0x00, which is Inf, the next higher value to the unrounded value.
// bool flag0 = ~u.int32 & 0x7f800000;
// // When all of the exponent bits are 1, the value is Inf or NaN.
// // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// // mantissa bit. Quiet NaN is indicated by the most significant mantissa
// // bit being 1. Signaling NaN is indicated by the most significant
// // mantissa bit being 0 but some other bit(s) being 1. If any of the
// // lower 16 bits of the mantissa are 1, we set the least significant bit
// // of the bfloat16 mantissa, in order to preserve signaling NaN in case
// // the bfloat16's mantissa bits are all 0.
// bool flag1 = !flag0 && (u.int32 & 0xffff);
// u.int32 +=
// flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
// u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
// return uint16_t(u.int32 >> 16);
// }
// // if using rtz
// else
// {
// union
// {
// float fp32;
// uint32_t int32;
// } u = {x};
// return uint16_t(u.int32 >> 16);
// }
// }
// // convert bfp16 to fp16 via fp32
// template <>
// inline __host__ __device__ constexpr half_t convert<half_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<half_t>(x_fp32);
// }
// // convert fp16 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, half_t>(half_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// // convert bfp16 to int32 via fp32
// template <>
// inline __host__ __device__ constexpr int32_t convert<int32_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<int32_t>(x_fp32);
// }
// // convert int32 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int32_t>(int32_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// // convert bfp16 to int8 via fp32
// template <>
// inline __host__ __device__ constexpr int8_t convert<int8_t, bhalf_t>(bhalf_t x)
// {
// float x_fp32 = convert<float>(x);
// return static_cast<int8_t>(x_fp32);
// }
// // convert int8 to bfp16 via fp32
// template <>
// inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int8_t>(int8_t x)
// {
// float x_fp32 = static_cast<float>(x);
// return convert<bhalf_t>(x_fp32);
// }
// private:
// bool BF16ConvertRTN_;
// };
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
...
include/ck/utility/inner_product.hpp
View file @
71a7ac8b
...
@@ -87,11 +87,10 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
...
@@ -87,11 +87,10 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
#else
#else
const
vector_type
<
half_t
,
2
>
a_vector
{
a
};
const
vector_type
<
half_t
,
2
>
a_vector
{
a
};
const
vector_type
<
half_t
,
2
>
b_vector
{
b
};
const
vector_type
<
half_t
,
2
>
b_vector
{
b
};
TypeConvert
type_convert
=
TypeConvert
();
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
c
+=
type_convert
.
convert
<
int32_t
>
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
type_convert
.
convert
<
int32_t
>
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
});
});
#endif
#endif
}
}
...
@@ -139,8 +138,7 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
...
@@ -139,8 +138,7 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
// Scalar int8 inner product: accumulate a*b into c.
// Operands are widened to int32 before the multiply so the product cannot
// overflow the int8 range.
template <>
__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
{
    c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
}
template
<
>
template
<
>
...
@@ -176,11 +174,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
...
@@ -176,11 +174,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
#else
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
TypeConvert
type_convert
=
TypeConvert
();
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
c
+=
type_convert
.
convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
.
convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
});
#endif
#endif
}
}
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
71a7ac8b
...
@@ -271,7 +271,7 @@ struct Tensor
...
@@ -271,7 +271,7 @@ struct Tensor
Tensor
<
OutT
>
ret
(
mDesc
);
Tensor
<
OutT
>
ret
(
mDesc
);
ck
::
ranges
::
transform
(
mData
,
ret
.
mData
.
begin
(),
[](
auto
value
)
{
ck
::
ranges
::
transform
(
mData
,
ret
.
mData
.
begin
(),
[](
auto
value
)
{
return
ck
::
T
ype
C
onvert
<
OutT
,
ck
::
remove_cvref_t
<
decltype
(
value
)
>>
{}
(
value
);
return
ck
::
t
ype
_c
onvert
<
OutT
>
(
value
);
});
});
return
ret
;
return
ret
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment