Commit b546a9f3 authored by Paul's avatar Paul
Browse files

Remove create_params

parent 44e6630c
...@@ -71,6 +71,8 @@ struct program ...@@ -71,6 +71,8 @@ struct program
shape get_parameter_shape(std::string name); shape get_parameter_shape(std::string name);
std::unordered_map<std::string, shape> get_parameter_shapes() const;
argument eval(parameter_map params) const; argument eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
......
...@@ -115,6 +115,20 @@ shape program::get_parameter_shape(std::string name) ...@@ -115,6 +115,20 @@ shape program::get_parameter_shape(std::string name)
return {}; return {};
} }
std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins:impl->instructions)
{
if(ins.op.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.op).parameter;
result[name] = ins.result;
}
}
return result;
}
bool program::has_instruction(instruction_ref ins) const bool program::has_instruction(instruction_ref ins) const
{ {
return std::find_if( return std::find_if(
......
...@@ -19,7 +19,12 @@ migraph::argument run_cpu() ...@@ -19,7 +19,12 @@ migraph::argument run_cpu()
V v; V v;
auto p = v.create_program(); auto p = v.create_program();
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
return p.eval(v.create_params()); migraph::program::parameter_map m;
for(auto&& x:p.get_parameter_shapes())
{
m[x.first] = migraph::generate_argument(x.second);
}
return p.eval(m);
} }
template <class V> template <class V>
...@@ -29,14 +34,12 @@ migraph::argument run_gpu() ...@@ -29,14 +34,12 @@ migraph::argument run_gpu()
auto p = v.create_program(); auto p = v.create_program();
p.compile(migraph::gpu::target{}); p.compile(migraph::gpu::target{});
auto m = v.create_params(); migraph::program::parameter_map m;
for(auto&& e : m) for(auto&& x:p.get_parameter_shapes())
{ {
e.second = migraph::gpu::to_gpu(e.second); m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
} }
m["output"] = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
return migraph::gpu::from_gpu(p.eval(m)); return migraph::gpu::from_gpu(p.eval(m));
} }
...@@ -61,8 +64,6 @@ struct test_literals ...@@ -61,8 +64,6 @@ struct test_literals
p.add_instruction(migraph::activation{"relu"}, conv); p.add_instruction(migraph::activation{"relu"}, conv);
return p; return p;
} }
migraph::program::parameter_map create_params() const { return {}; }
}; };
struct test_add struct test_add
...@@ -76,14 +77,6 @@ struct test_add ...@@ -76,14 +77,6 @@ struct test_add
p.add_instruction(migraph::add{}, x, y); p.add_instruction(migraph::add{}, x, y);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {3}});
m["y"] = migraph::generate_argument({migraph::shape::float_type, {3}});
return m;
}
}; };
struct test_add_broadcast struct test_add_broadcast
...@@ -98,14 +91,6 @@ struct test_add_broadcast ...@@ -98,14 +91,6 @@ struct test_add_broadcast
p.add_instruction(migraph::add{}, x, by); p.add_instruction(migraph::add{}, x, by);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {2, 2, 3}});
m["y"] = migraph::generate_argument({migraph::shape::float_type, {2, 2}});
return m;
}
}; };
struct test_conv_relu struct test_conv_relu
...@@ -120,14 +105,6 @@ struct test_conv_relu ...@@ -120,14 +105,6 @@ struct test_conv_relu
p.add_instruction(migraph::activation{"relu"}, conv); p.add_instruction(migraph::activation{"relu"}, conv);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
return m;
}
}; };
struct test_conv_pooling struct test_conv_pooling
...@@ -144,14 +121,6 @@ struct test_conv_pooling ...@@ -144,14 +121,6 @@ struct test_conv_pooling
p.add_instruction(migraph::activation{"relu"}, pooling); p.add_instruction(migraph::activation{"relu"}, pooling);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 32, 32}});
m["w"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 3, 3}});
return m;
}
}; };
struct test_gemm struct test_gemm
...@@ -164,14 +133,6 @@ struct test_gemm ...@@ -164,14 +133,6 @@ struct test_gemm
p.add_instruction(migraph::gemm{}, a, b); p.add_instruction(migraph::gemm{}, a, b);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["a"] = migraph::generate_argument({migraph::shape::float_type, {4, 5}});
m["b"] = migraph::generate_argument({migraph::shape::float_type, {5, 3}});
return m;
}
}; };
struct test_contiguous struct test_contiguous
...@@ -184,14 +145,6 @@ struct test_contiguous ...@@ -184,14 +145,6 @@ struct test_contiguous
p.add_instruction(migraph::contiguous{}, x); p.add_instruction(migraph::contiguous{}, x);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] =
migraph::generate_argument({migraph::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}});
return m;
}
}; };
struct test_transpose struct test_transpose
...@@ -206,13 +159,6 @@ struct test_transpose ...@@ -206,13 +159,6 @@ struct test_transpose
p.add_instruction(migraph::contiguous{}, l); p.add_instruction(migraph::contiguous{}, l);
return p; return p;
} }
migraph::program::parameter_map create_params() const
{
migraph::program::parameter_map m;
m["x"] = migraph::generate_argument({migraph::shape::float_type, {4, 3, 4, 4}});
return m;
}
}; };
int main() int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment