Commit b546a9f3 authored by Paul's avatar Paul
Browse files

Remove create_params

parent 44e6630c
......@@ -71,6 +71,8 @@ struct program
shape get_parameter_shape(std::string name);
std::unordered_map<std::string, shape> get_parameter_shapes() const;
argument eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const;
......
......@@ -115,6 +115,20 @@ shape program::get_parameter_shape(std::string name)
return {};
}
std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins:impl->instructions)
{
if(ins.op.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.op).parameter;
result[name] = ins.result;
}
}
return result;
}
bool program::has_instruction(instruction_ref ins) const
{
return std::find_if(
......
......@@ -19,7 +19,12 @@ migraph::argument run_cpu()
V v;
auto p = v.create_program();
p.compile(migraph::cpu::cpu_target{});
return p.eval(v.create_params());
migraph::program::parameter_map m;
for(auto&& x:p.get_parameter_shapes())
{
m[x.first] = migraph::generate_argument(x.second);
}
return p.eval(m);
}
template <class V>
......@@ -29,14 +34,12 @@ migraph::argument run_gpu()
auto p = v.create_program();
p.compile(migraph::gpu::target{});
auto m = v.create_params();
for(auto&& e : m)
migraph::program::parameter_map m;
for(auto&& x:p.get_parameter_shapes())
{
e.second = migraph::gpu::to_gpu(e.second);
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
return migraph::gpu::from_gpu(p.eval(m));
}
......@@ -61,8 +64,6 @@ struct test_literals
p.add_instruction(migraph::activation{"relu"}, conv);
return p;
}
migraph::program::parameter_map create_params() const { return {}; }
};
struct test_add
......@@ -76,14 +77,6 @@ struct test_add
p.add_instruction(migraph::add{}, x, y);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {3}});
m["y"] = migraph::generate_argument({migraph::shape::float_type, {3}});
return m;
}
};
struct test_add_broadcast
......@@ -98,14 +91,6 @@ struct test_add_broadcast
p.add_instruction(migraph::add{}, x, by);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {2, 2, 3}});
m["y"] = migraph::generate_argument({migraph::shape::float_type, {2, 2}});
return m;
}
};
struct test_conv_relu
......@@ -120,14 +105,6 @@ struct test_conv_relu
p.add_instruction(migraph::activation{"relu"}, conv);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
return m;
}
};
struct test_conv_pooling
......@@ -144,14 +121,6 @@ struct test_conv_pooling
p.add_instruction(migraph::activation{"relu"}, pooling);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 32, 32}});
m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
return m;
}
};
struct test_gemm
......@@ -164,14 +133,6 @@ struct test_gemm
p.add_instruction(migraph::gemm{}, a, b);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["a"] = migraph::generate_argument({migraph::shape::float_type, {4, 5}});
m["b"] = migraph::generate_argument({migraph::shape::float_type, {5, 3}});
return m;
}
};
struct test_contiguous
......@@ -184,14 +145,6 @@ struct test_contiguous
p.add_instruction(migraph::contiguous{}, x);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] =
migraph::generate_argument({migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}});
return m;
}
};
struct test_transpose
......@@ -206,13 +159,6 @@ struct test_transpose
p.add_instruction(migraph::contiguous{}, l);
return p;
}
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 4, 4}});
return m;
}
};
int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment