Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
faef98bf
"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "2a86fa207101c1c2f727fb9e04437b6b075e0788"
Commit
faef98bf
authored
May 22, 2019
by
Shucai Xiao
Browse files
reduce the rounding error in converting to int8
parent
41344324
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
29 deletions
+54
-29
src/include/migraphx/op/convert.hpp
src/include/migraphx/op/convert.hpp
+1
-0
src/targets/gpu/device/convert.cpp
src/targets/gpu/device/convert.cpp
+53
-29
No files found.
src/include/migraphx/op/convert.hpp
View file @
faef98bf
...
...
@@ -42,6 +42,7 @@ struct convert : unary<convert>
float
res
=
scale
*
x
+
shift
;
if
(
target_type
==
shape
::
int8_type
)
{
res
=
res
+
0.5
f
;
res
=
res
>
127.0
?
127.0
:
res
;
res
=
res
<
-
128.0
?
-
128.0
:
res
;
}
...
...
src/targets/gpu/device/convert.cpp
View file @
faef98bf
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/device/nary.hpp>
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
convert
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
float
scale
,
float
shift
,
shape
::
type_t
target_type
)
namespace
op
{
struct
convert
:
unary
<
convert
>
{
result
.
visit
([
&
](
auto
output
)
{
arg
.
visit
([
&
](
auto
input
)
{
const
auto
*
input_ptr
=
device_cast
(
input
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
if
(
target_type
==
shape
::
int8_type
)
shape
::
type_t
target_type
=
shape
::
half_type
;
float
scale
=
1.0
f
;
float
shift
=
0.0
f
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
target_type
,
"target_type"
),
f
(
self
.
scale
,
"scale"
),
f
(
self
.
shift
,
"shift"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
gs_launch
(
stream
,
result
.
get_shape
().
elements
())([
=
](
auto
i
)
{
output_ptr
[
i
]
=
std
::
min
<
int8_t
>
(
std
::
max
<
float
>
(
-
128
,
input_ptr
[
i
]
*
scale
+
shift
),
127
);
});
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
{
target_type
,
inputs
.
at
(
0
).
lens
(),
inputs
.
at
(
0
).
strides
()};
}
else
auto
apply
()
const
{
return
[
&
](
auto
x
)
{
float
res
=
scale
*
x
+
shift
;
if
(
target_type
==
shape
::
int8_type
)
{
gs_launch
(
stream
,
result
.
get_shape
().
elements
())(
[
=
](
auto
i
)
{
output_ptr
[
i
]
=
input_ptr
[
i
]
*
scale
+
shift
;
});
res
=
res
+
0.5
f
;
res
=
res
>
127.0
?
127.0
:
res
;
res
=
res
<
-
128.0
?
-
128.0
:
res
;
}
});
});
}
}
// namespace device
}
// namespace gpu
return
res
;
};
}
convert
(
shape
::
type_t
t
)
:
target_type
{
t
}
{}
convert
(
shape
::
type_t
t
,
float
sle
,
float
sft
)
:
target_type
{
t
},
scale
{
sle
},
shift
{
sft
}
{}
convert
()
{}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment