Commit 875bb4f7 authored by Paul's avatar Paul
Browse files

Fix tidy errors

parent 0ac28b96
...@@ -185,7 +185,7 @@ argument instruction::eval() const ...@@ -185,7 +185,7 @@ argument instruction::eval() const
void instruction::finalize(context& ctx) void instruction::finalize(context& ctx)
{ {
if (has_finalize(this->op)) if(has_finalize(this->op))
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs())); this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
} }
......
...@@ -81,17 +81,18 @@ shape miopen_convolution::compile(context& ctx, ...@@ -81,17 +81,18 @@ shape miopen_convolution::compile(context& ctx,
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Find convolution failed"); MIGRAPHX_THROW("Find convolution failed");
handle = ctx.get_stream().get_miopen(); handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo; algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {perf.memory}};
} }
void miopen_convolution::finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs) void miopen_convolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{ {
if (handle == ctx.get_stream().get_miopen()) if(handle == ctx.get_stream().get_miopen())
return; return;
// TODO: Check that workspace hasn't changed // TODO: Check that workspace hasn't changed
compile(ctx, output_shape, inputs); compile(ctx, output_shape, std::move(inputs));
} }
} // namespace gpu } // namespace gpu
......
...@@ -274,14 +274,8 @@ struct miopen_conv_bias ...@@ -274,14 +274,8 @@ struct miopen_conv_bias
return f.execute(ctx, fargs, args[0], args[4]); return f.execute(ctx, fargs, args[0], args[4]);
} }
void finalize(context& ctx, const shape&, const std::vector<shape>&) void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
{ shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
f.compile(ctx);
}
shape get_workspace(context& ctx)
{
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
...@@ -321,14 +315,8 @@ struct miopen_conv_bias_relu ...@@ -321,14 +315,8 @@ struct miopen_conv_bias_relu
miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0); miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
return f.execute(ctx, fargs, args[0], args[4]); return f.execute(ctx, fargs, args[0], args[4]);
} }
void finalize(context& ctx, const shape&, const std::vector<shape>&) void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
{ shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
f.compile(ctx);
}
shape get_workspace(context& ctx)
{
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
}; };
......
...@@ -194,7 +194,8 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -194,7 +194,8 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
} }
template <class T> template <class T>
auto finalize_op(rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input) auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void()) -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{ {
x.finalize(auto_any_cast(ctx), output_shape, input); x.finalize(auto_any_cast(ctx), output_shape, input);
...@@ -202,7 +203,8 @@ auto finalize_op(rank<1>, T& x, context& ctx, const shape& output_shape, const s ...@@ -202,7 +203,8 @@ auto finalize_op(rank<1>, T& x, context& ctx, const shape& output_shape, const s
template <class T> template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&) void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{} {
}
template <class T> template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input) void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
...@@ -211,11 +213,8 @@ void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -211,11 +213,8 @@ void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vecto
} }
template <class T> template <class T>
auto has_finalize_op(rank<1>, auto has_finalize_op(
T& x, rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
context& ctx,
const shape& output_shape,
const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{}); -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T> template <class T>
...@@ -223,8 +222,11 @@ auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shap ...@@ -223,8 +222,11 @@ auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shap
-> std::false_type; -> std::false_type;
template <class T> template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op( auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
rank<1>{}, std::declval<T&>(), std::declval<context&>(), std::declval<const shape&>(), std::declval<std::vector<shape>>())) std::declval<T&>(),
std::declval<context&>(),
std::declval<const shape&>(),
std::declval<std::vector<shape>>()))
{ {
return {}; return {};
} }
...@@ -240,7 +242,11 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op( ...@@ -240,7 +242,11 @@ auto has_finalize_op(const T&) -> decltype(has_finalize_op(
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
virtual('finalize', ctx = 'context&', output = 'const shape&', input = 'const std::vector<shape>&', default = 'finalize_op'), virtual('finalize',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<shape>&',
default = 'finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment