Commit 168584b7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from gpu_div branch

parents 6c9f9277 8d059502
...@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins) bool skip_propogate(instruction_ref ins)
{ {
if(ins->name() == "@literal") if(ins->name() == "contiguous")
return true; return skip_propogate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar())
return true; return true;
...@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const ...@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const
ins->outputs().end()); ins->outputs().end());
for(auto child : children) for(auto child : children)
{ {
if(skip_propogate(child)) if(child->name() == "@literal" or skip_propogate(child))
{ {
self(child); self(child);
continue; continue;
......
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void div(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return y / x; }); nary(stream, result, arg1, arg2)([](auto x, auto y) { return x / y; });
} }
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)( nary(stream, result, arg1, arg2)(
[](auto e, auto b) { return ::pow(to_hip_type(b), to_hip_type(e)); }); [](auto b, auto e) { return ::pow(to_hip_type(b), to_hip_type(e)); });
} }
} // namespace device } // namespace device
......
...@@ -8,7 +8,7 @@ namespace device { ...@@ -8,7 +8,7 @@ namespace device {
void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2) void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; }); nary(stream, result, arg1, arg2)([](auto x, auto y) { return x - y; });
} }
} // namespace device } // namespace device
......
...@@ -88,7 +88,7 @@ struct binary_device : oper<Derived> ...@@ -88,7 +88,7 @@ struct binary_device : oper<Derived>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
F(ctx.get_stream().get(), args[2], args[1], args[0]); F(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2]; return args[2];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment