Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
936e76b2
Commit
936e76b2
authored
Sep 25, 2023
by
Paul
Browse files
Add doulbe-precision evaluation
parent
2ee297d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
0 deletions
+57
-0
src/instruction.cpp
src/instruction.cpp
+57
-0
No files found.
src/instruction.cpp
View file @
936e76b2
...
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
...
...
@@ -331,8 +332,63 @@ bool instruction::can_eval() const
}
}
template
<
class
InstructionRef
>
static
bool
is_low_precision_float
(
InstructionRef
ins
)
{
auto
t
=
ins
->
get_shape
().
type
();
return
contains
({
shape
::
float_type
,
shape
::
half_type
},
t
);
}
argument
instruction
::
eval
(
bool
check_eval
)
const
{
#if 0
if (not this->can_eval())
return {};
auto r = fix<argument>([](auto self, auto ins) -> argument {
if (ins->name() == "@literal")
{
if (is_low_precision_float(ins))
{
auto dlit = transform(ins->get_literal(), [](auto x) -> double {return x;});
return dlit.get_argument();
}
return ins->get_literal().get_argument();
}
if (ins->name() == "convert")
{
if (is_low_precision_float(ins))
{
auto x = self(ins->inputs().front());
if(is_low_precision_float(ins->inputs().front()))
{
return x;
}
auto convert = make_op("convert", {{"target_type", to_value(shape::double_type)}});
auto s = convert.compute_shape({x.get_shape()});
return convert.compute(s, {x});
}
}
if(is_context_free(ins->get_operator()))
{
std::vector<argument> args;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(args),
[&](auto arg) { return self(arg); });
auto normal_op = ins->normalized_operator();
auto s = normal_op.compute_shape(to_shapes(args));
return normal_op.compute(s, args);
}
return {};
})(this);
auto convert = make_op("convert", {{"target_type", to_value(this->get_shape().type())}});
auto s = convert.compute_shape({r.get_shape()});
return convert.compute(s, {r});
#else
if
(
op
.
name
()
==
"@literal"
)
{
return
this
->
get_literal
().
get_argument
();
...
...
@@ -349,6 +405,7 @@ argument instruction::eval(bool check_eval) const
return
normalized_operator
().
compute
(
result
,
args
);
}
return
{};
#endif
}
void
instruction
::
finalize
(
context
&
ctx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment