Commit 6a09f672 authored by mei-ye's avatar mei-ye
Browse files

add back const to finish, recode is_split_point and is_merge_point

parent 49f483af
......@@ -20,7 +20,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context
{
/// Wait for any tasks in the context to complete
void finish();
void finish() const;
};
#else
......@@ -30,7 +30,7 @@ struct context
*
* struct context
* {
* void finish() ;
* void finish() const;
* };
*
*/
......@@ -92,7 +92,7 @@ struct context
return private_detail_te_get_handle().type();
}
void finish()
void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish();
......@@ -111,7 +111,7 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void finish() = 0;
virtual void finish() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -142,7 +142,7 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() override { private_detail_te_value.finish(); }
void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -149,28 +149,34 @@ void dom_info::compute_dom(bool reversed)
bool dom_info::is_split_point(instruction_ref ins)
{
std::set<int> stream_set;
int stream = -1;
for(auto&& arg : ins->outputs())
{
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
if(arg_stream < 0)
continue;
if((stream >= 0) && (arg_stream != stream))
return true;
stream = arg_stream;
}
return (stream_set.size() > 1);
return false;
}
// Identify merge points. A merge point has more than one
// inputs that are executed in different streams.
bool dom_info::is_merge_point(instruction_ref ins)
{
std::set<int> stream_set;
int stream = -1;
for(auto&& arg : ins->inputs())
{
int arg_stream = arg->get_stream();
if(arg_stream >= 0)
stream_set.insert(arg_stream);
if(arg_stream < 0)
continue;
if((stream >= 0) && (arg_stream != stream))
return true;
stream = arg_stream;
}
return (stream_set.size() > 1);
return false;
}
// Propagate split points through the graph and identify concurrent instructions.
......
......@@ -9,7 +9,7 @@ namespace cpu {
struct context
{
void finish() {}
void finish() const {}
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -148,7 +148,7 @@ struct context
{
context(std::size_t n = 0) : current_device(std::make_shared<hip_device>(n)) {}
hip_device& get_current_device()
hip_device& get_current_device() const
{
assert(current_device != nullptr);
return *current_device;
......@@ -165,7 +165,7 @@ struct context
void wait_event(int event) { get_current_device().wait_event(event); }
std::vector<argument> literals{};
void finish()
void finish() const
{
get_current_device().stream_sync();
gpu_sync();
......
......@@ -20,14 +20,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context
{
/// Wait for any tasks in the context to complete
void finish();
void finish() const;
};
#else
<%
interface('context',
virtual('finish', returns='void'),
virtual('finish', returns='void', const=True),
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment