Commit 75f68cc4 authored by Paul's avatar Paul
Browse files

Add flag to drop additional args

parent 867e588b
...@@ -40,18 +40,20 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); ...@@ -40,18 +40,20 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
std::size_t additional_args = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.op, "op")); return pack(f(self.op, "op"), f(self.additional_args, "additional_args"));
} }
std::string name() const { return "gpu::precompile_op"; } std::string name() const { return "gpu::precompile_op"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
inputs.pop_back(); // Pop off additional args
inputs.resize(inputs.size() - additional_args);
return op.compute_shape(inputs, mods); return op.compute_shape(inputs, mods);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment