Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4c6c750a
Commit
4c6c750a
authored
Apr 06, 2023
by
Rosty Geyyer
Browse files
Add TypeConvert class and start refactoring
parent
dbd8f94b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
147 additions
and
107 deletions
+147
-107
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+135
-100
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+8
-5
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+4
-2
No files found.
include/ck/utility/data_type.hpp
100644 → 100755
View file @
4c6c750a
...
@@ -942,19 +942,34 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
...
@@ -942,19 +942,34 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x32_t
=
typename
vector_type
<
int8_t
,
32
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
using
int8x64_t
=
typename
vector_type
<
int8_t
,
64
>::
type
;
// Convert X to Y
class
TypeConvert
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
public:
// constructor
__host__
__device__
TypeConvert
()
{
BF16ConvertRTN_
=
false
;
// use round to zero by default
}
// switch bf16 conversion mode to rtn
__host__
__device__
void
SetBF16ConvertRTN
()
{
BF16ConvertRTN_
=
true
;
}
// switch bf16 conversion mode to rtz
__host__
__device__
void
SetBF16ConvertRTZ
()
{
BF16ConvertRTN_
=
false
;
}
// convert for all types except bf16
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
return
static_cast
<
Y
>
(
x
);
}
}
// convert bfp16 to fp32
// convert bfp16 to fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
float
type_
convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
float
convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
{
{
union
union
{
{
uint32_t
int32
;
uint32_t
int32
;
...
@@ -962,12 +977,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
...
@@ -962,12 +977,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
}
u
=
{
uint32_t
(
x
)
<<
16
};
}
u
=
{
uint32_t
(
x
)
<<
16
};
return
u
.
fp32
;
return
u
.
fp32
;
}
}
// convert fp32 to bfp16
// convert fp32 to bfp16
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
inline
__host__
__device__
constexpr
bhalf_t
convert
<
bhalf_t
,
float
>
(
float
x
)
{
{
// if using rtn
if
(
BF16ConvertRTN_
)
{
union
union
{
{
float
fp32
;
float
fp32
;
...
@@ -1002,65 +1020,82 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
...
@@ -1002,65 +1020,82 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
// the bfloat16's mantissa bits are all 0.
// the bfloat16's mantissa bits are all 0.
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
bool
flag1
=
!
flag0
&&
(
u
.
int32
&
0xffff
);
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
+=
flag0
?
0x7fff
+
((
u
.
int32
>>
16
)
&
1
)
:
0
;
// Round to nearest, round to even
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
u
.
int32
|=
flag1
?
0x10000
:
0x0
;
// Preserve signaling NaN
return
uint16_t
(
u
.
int32
>>
16
);
return
uint16_t
(
u
.
int32
>>
16
);
}
}
// if using rtz
else
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
x
};
// convert bfp16 to fp16 via fp32
return
uint16_t
(
u
.
int32
>>
16
);
template
<
>
}
inline
__host__
__device__
constexpr
half_t
type_convert
<
half_t
,
bhalf_t
>
(
bhalf_t
x
)
}
{
float
x_fp32
=
type_convert
<
float
>
(
x
);
// convert bfp16 to fp16 via fp32
template
<
>
inline
__host__
__device__
constexpr
half_t
convert
<
half_t
,
bhalf_t
>
(
bhalf_t
x
)
{
float
x_fp32
=
convert
<
float
>
(
x
);
return
static_cast
<
half_t
>
(
x_fp32
);
return
static_cast
<
half_t
>
(
x_fp32
);
}
}
// convert fp16 to bfp16 via fp32
// convert fp16 to bfp16 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_
convert
<
bhalf_t
,
half_t
>
(
half_t
x
)
inline
__host__
__device__
constexpr
bhalf_t
convert
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_
convert
<
bhalf_t
>
(
x_fp32
);
return
convert
<
bhalf_t
>
(
x_fp32
);
}
}
// convert bfp16 to int32 via fp32
// convert bfp16 to int32 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
int32_t
type_
convert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
int32_t
convert
<
int32_t
,
bhalf_t
>
(
bhalf_t
x
)
{
{
float
x_fp32
=
type_
convert
<
float
>
(
x
);
float
x_fp32
=
convert
<
float
>
(
x
);
return
static_cast
<
int32_t
>
(
x_fp32
);
return
static_cast
<
int32_t
>
(
x_fp32
);
}
}
// convert int32 to bfp16 via fp32
// convert int32 to bfp16 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_
convert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
inline
__host__
__device__
constexpr
bhalf_t
convert
<
bhalf_t
,
int32_t
>
(
int32_t
x
)
{
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_
convert
<
bhalf_t
>
(
x_fp32
);
return
convert
<
bhalf_t
>
(
x_fp32
);
}
}
// convert bfp16 to int8 via fp32
// convert bfp16 to int8 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
int8_t
type_
convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
inline
__host__
__device__
constexpr
int8_t
convert
<
int8_t
,
bhalf_t
>
(
bhalf_t
x
)
{
{
float
x_fp32
=
type_
convert
<
float
>
(
x
);
float
x_fp32
=
convert
<
float
>
(
x
);
return
static_cast
<
int8_t
>
(
x_fp32
);
return
static_cast
<
int8_t
>
(
x_fp32
);
}
}
// convert int8 to bfp16 via fp32
// convert int8 to bfp16 via fp32
template
<
>
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_
convert
<
bhalf_t
,
int8_t
>
(
int8_t
x
)
inline
__host__
__device__
constexpr
bhalf_t
convert
<
bhalf_t
,
int8_t
>
(
int8_t
x
)
{
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert
<
bhalf_t
>
(
x_fp32
);
return
convert
<
bhalf_t
>
(
x_fp32
);
}
}
private:
bool
BF16ConvertRTN_
;
};
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
...
...
include/ck/utility/inner_product.hpp
100644 → 100755
View file @
4c6c750a
...
@@ -87,10 +87,11 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
...
@@ -87,10 +87,11 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
#else
#else
const
vector_type
<
half_t
,
2
>
a_vector
{
a
};
const
vector_type
<
half_t
,
2
>
a_vector
{
a
};
const
vector_type
<
half_t
,
2
>
b_vector
{
b
};
const
vector_type
<
half_t
,
2
>
b_vector
{
b
};
TypeConvert
type_convert
=
TypeConvert
();
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
c
+=
type_convert
.
convert
<
int32_t
>
(
a_vector
.
AsType
<
half_t
>
()[
i
])
*
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
type_convert
.
convert
<
int32_t
>
(
b_vector
.
AsType
<
half_t
>
()[
i
]);
});
});
#endif
#endif
}
}
...
@@ -138,7 +139,8 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
...
@@ -138,7 +139,8 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
template
<
>
template
<
>
__device__
void
inner_product
<
int8_t
,
int8_t
,
int32_t
>
(
const
int8_t
&
a
,
const
int8_t
&
b
,
int32_t
&
c
)
__device__
void
inner_product
<
int8_t
,
int8_t
,
int32_t
>
(
const
int8_t
&
a
,
const
int8_t
&
b
,
int32_t
&
c
)
{
{
c
+=
type_convert
<
int32_t
>
(
a
)
*
type_convert
<
int32_t
>
(
b
);
TypeConvert
type_convert
=
TypeConvert
();
c
+=
type_convert
.
convert
<
int32_t
>
(
a
)
*
type_convert
.
convert
<
int32_t
>
(
b
);
}
}
template
<
>
template
<
>
...
@@ -174,10 +176,11 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
...
@@ -174,10 +176,11 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else
#else
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
a_vector
{
a
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
const
vector_type
<
int8_t
,
4
>
b_vector
{
b
};
TypeConvert
type_convert
=
TypeConvert
();
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
4
,
1
>
{}([
&
](
auto
i
)
{
c
+=
type_convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
c
+=
type_convert
.
convert
<
int32_t
>
(
a_vector
.
AsType
<
int8_t
>
()[
i
])
*
type_convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
type_convert
.
convert
<
int32_t
>
(
b_vector
.
AsType
<
int8_t
>
()[
i
]);
});
});
#endif
#endif
}
}
...
...
library/include/ck/library/utility/host_tensor.hpp
100644 → 100755
View file @
4c6c750a
...
@@ -270,8 +270,10 @@ struct Tensor
...
@@ -270,8 +270,10 @@ struct Tensor
{
{
Tensor
<
OutT
>
ret
(
mDesc
);
Tensor
<
OutT
>
ret
(
mDesc
);
ck
::
ranges
::
transform
(
ck
::
ranges
::
transform
(
mData
,
ret
.
mData
.
begin
(),
[](
auto
value
)
{
mData
,
ret
.
mData
.
begin
(),
[](
auto
value
)
{
return
ck
::
type_convert
<
OutT
>
(
value
);
});
ck
::
TypeConvert
type_convert
=
ck
::
TypeConvert
();
return
type_convert
.
convert
<
OutT
>
(
value
);
});
return
ret
;
return
ret
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment