Commit 1b56b01b authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Work in progress initial stream sync spelunking

parent 4918d769
...@@ -242,6 +242,11 @@ bool equal(const T& x, const T& y) ...@@ -242,6 +242,11 @@ bool equal(const T& x, const T& y)
} }
std::vector<argument> run(program& p, const parameter_map& params) { return p.eval(params); } std::vector<argument> run(program& p, const parameter_map& params) { return p.eval(params); }
std::vector<argument>
run_async(program& p, const parameter_map& params, const execution_environment& exec_env)
{
return p.eval(params, exec_env);
}
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); } std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
......
...@@ -972,7 +972,8 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -972,7 +972,8 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
} }
/// Run the program using the inputs passed in /// Run the program using the inputs passed in
arguments eval(const program_parameters& pparams) const arguments eval(const program_parameters& pparams,
const execution_environment& e = {nullptr, false}) const
{ {
migraphx_arguments_t pout; migraphx_arguments_t pout;
call(&migraphx_program_run, &pout, this->get_handle_ptr(), pparams.get_handle_ptr()); call(&migraphx_program_run, &pout, this->get_handle_ptr(), pparams.get_handle_ptr());
......
...@@ -76,8 +76,13 @@ struct context ...@@ -76,8 +76,13 @@ struct context
value to_value() const; value to_value() const;
// (optional) // (optional)
void from_value(const value& v); void from_value(const value& v);
// (optional) // (optional)
any_ptr get_queue(); any_ptr get_queue();
// Used for async streams
void wait_for(any_ptr queue);
void finish_on(any_ptr queue);
// //
void finish() const; void finish() const;
}; };
...@@ -165,6 +170,16 @@ struct context ...@@ -165,6 +170,16 @@ struct context
return (*this).private_detail_te_get_handle().get_queue(); return (*this).private_detail_te_get_handle().get_queue();
} }
void wait_for(any_ptr queue)
{
// TODO
}
void finish_on(any_ptr queue)
{
// TODO
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
......
...@@ -77,6 +77,7 @@ struct program ...@@ -77,6 +77,7 @@ struct program
std::unordered_map<std::string, shape> get_parameter_shapes() const; std::unordered_map<std::string, shape> get_parameter_shapes() const;
std::vector<argument> eval(parameter_map params) const; std::vector<argument> eval(parameter_map params) const;
std::vector<argument> eval(parameter_map params, execution_environment exec_env) const;
std::size_t size() const; std::size_t size() const;
......
...@@ -194,6 +194,12 @@ struct hip_device ...@@ -194,6 +194,12 @@ struct hip_device
std::unordered_map<std::string, argument> preallocations{}; std::unordered_map<std::string, argument> preallocations{};
}; };
struct excecution_environment
{
any_ptr queue = nullptr;
bool async = false;
};
struct context struct context
{ {
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1)) context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
...@@ -265,6 +271,9 @@ struct context ...@@ -265,6 +271,9 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
void wait_for(any_ptr queue) {}
void finish_on(any_ptr queue) {}
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
private: private:
......
...@@ -87,13 +87,14 @@ migraphx_status try_(F f, bool output = true) // NOLINT ...@@ -87,13 +87,14 @@ migraphx_status try_(F f, bool output = true) // NOLINT
shape::type_t to_shape_type(migraphx_shape_datatype_t t) shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{ {
switch(t) switch(t)
{
void wait_for() const
{call(&
case migraphx_shape_tuple_type: return shape::tuple_type; case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x; case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) #undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT
#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT }
}
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment