Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
34348bd1
Commit
34348bd1
authored
Apr 19, 2023
by
Rosty Geyyer
Browse files
Add a flag to pick converion method
parent
2797227f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
8 deletions
+70
-8
include/ck/ck.hpp
include/ck/ck.hpp
+6
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+36
-3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+8
-4
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+20
-1
No files found.
include/ck/ck.hpp
View file @
34348bd1
...
...
@@ -173,6 +173,12 @@
#define CK_WORKAROUND_DENORM_FIX 0
#endif
// flag to enable high precision data conversion
// 0 - fast, 1 - high precision
#ifndef CK_EXPERIMENTAL_CONVERT_PRECISION
#define CK_EXPERIMENTAL_CONVERT_PRECISION 1
#endif
namespace
ck
{
enum
struct
InMemoryDataOperationEnum
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
34348bd1
...
...
@@ -89,15 +89,48 @@ struct UnaryConvert
struct
UnaryConvertPrecision
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
Y
>
(
x
);
y
=
type_convert_precision
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
y
=
type_convert_precision
<
half_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
y
=
type_convert_precision
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
double
,
double
>
(
double
&
y
,
const
double
&
x
)
const
{
y
=
type_convert_precision
<
double
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
int8_t
,
int8_t
>
(
int8_t
&
y
,
const
int8_t
&
x
)
const
{
y
=
type_convert_precision
<
int8_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
half_t
>
(
bhalf_t
&
y
,
const
half_t
&
x
)
const
{
y
=
type_convert_precision
<
bhalf_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert_
bf16_rtn
(
x
);
y
=
type_convert_
precision
<
bhalf_t
>
(
x
);
}
};
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
34348bd1
...
...
@@ -6,11 +6,10 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace
ck
{
namespace
detail
{
...
...
@@ -348,9 +347,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
});
}
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// pick the right conversion method
#if CK_EXPERIMENTAL_CONVERT_PRECISION
using
UnaryConvert
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvertPrecision
;
#else
using
UnaryConvert
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
#endif
// convert from SrcData to DstData here
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
{}(
dst_thread_scratch_
(
idx
),
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
UnaryConvert
{}(
dst_thread_scratch_
(
idx
),
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
});
#endif
}
...
...
include/ck/utility/data_type.hpp
View file @
34348bd1
...
...
@@ -1031,8 +1031,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return
type_convert
<
bhalf_t
>
(
x_fp32
);
}
// Convert X to Y with highest possible precision
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_precision
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
// Convert fp32 to bf16 with RTN if higher precision is needed
__host__
__device__
constexpr
bhalf_t
type_convert_bf16_rtn
(
float
x
)
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert_precision
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
...
...
@@ -1074,6 +1084,15 @@ __host__ __device__ constexpr bhalf_t type_convert_bf16_rtn(float x)
return
uint16_t
(
u
.
int32
>>
16
);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert_precision
<
bhalf_t
,
half_t
>
(
half_t
x
)
{
float
x_fp32
=
static_cast
<
float
>
(
x
);
return
type_convert_precision
<
bhalf_t
>
(
x_fp32
);
}
template
<
typename
T
>
struct
NumericLimits
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment