Commit 1b56b01b authored by Ted Themistokleous's avatar Ted Themistokleous
Browse files

Work in progress initial stream sync spelunking

parent 4918d769
......@@ -242,6 +242,11 @@ bool equal(const T& x, const T& y)
}
std::vector<argument> run(program& p, const parameter_map& params) { return p.eval(params); }
std::vector<argument>
run_async(program& p, const parameter_map& params, const execution_environment& exec_env)
{
return p.eval(params, exec_env);
}
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
......
......@@ -972,7 +972,8 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
}
/// Run the program using the inputs passed in
arguments eval(const program_parameters& pparams) const
arguments eval(const program_parameters& pparams,
const execution_environment& e = {nullptr, false}) const
{
migraphx_arguments_t pout;
call(&migraphx_program_run, &pout, this->get_handle_ptr(), pparams.get_handle_ptr());
......
......@@ -76,8 +76,13 @@ struct context
value to_value() const;
// (optional)
void from_value(const value& v);
// (optional)
any_ptr get_queue();
// Used for async streams
void wait_for(any_ptr queue);
void finish_on(any_ptr queue);
//
void finish() const;
};
......@@ -165,6 +170,16 @@ struct context
return (*this).private_detail_te_get_handle().get_queue();
}
void wait_for(any_ptr queue)
{
// TODO
}
void finish_on(any_ptr queue)
{
// TODO
}
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
......
......@@ -77,6 +77,7 @@ struct program
std::unordered_map<std::string, shape> get_parameter_shapes() const;
std::vector<argument> eval(parameter_map params) const;
std::vector<argument> eval(parameter_map params, execution_environment exec_env) const;
std::size_t size() const;
......
......@@ -194,6 +194,12 @@ struct hip_device
std::unordered_map<std::string, argument> preallocations{};
};
struct excecution_environment
{
any_ptr queue = nullptr;
bool async = false;
};
struct context
{
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
......@@ -265,6 +271,9 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams);
}
void wait_for(any_ptr queue) {}
void finish_on(any_ptr queue) {}
any_ptr get_queue() { return get_stream().get(); }
private:
......
......@@ -87,13 +87,14 @@ migraphx_status try_(F f, bool output = true) // NOLINT
shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{
switch(t)
{
void wait_for() const
{call(&
case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
#undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT
}
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) #undef MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT
}
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment