Commit 6a09f672 authored by mei-ye's avatar mei-ye
Browse files

add back const to finish, recode is_split_point and is_merge_point

parent 49f483af
...@@ -20,7 +20,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,7 +20,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context struct context
{ {
/// Wait for any tasks in the context to complete /// Wait for any tasks in the context to complete
void finish(); void finish() const;
}; };
#else #else
...@@ -30,7 +30,7 @@ struct context ...@@ -30,7 +30,7 @@ struct context
* *
* struct context * struct context
* { * {
* void finish() ; * void finish() const;
* }; * };
* *
*/ */
...@@ -92,7 +92,7 @@ struct context ...@@ -92,7 +92,7 @@ struct context
return private_detail_te_get_handle().type(); return private_detail_te_get_handle().type();
} }
void finish() void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish(); (*this).private_detail_te_get_handle().finish();
...@@ -111,7 +111,7 @@ struct context ...@@ -111,7 +111,7 @@ struct context
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual void finish() = 0; virtual void finish() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -142,7 +142,7 @@ struct context ...@@ -142,7 +142,7 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -149,28 +149,34 @@ void dom_info::compute_dom(bool reversed) ...@@ -149,28 +149,34 @@ void dom_info::compute_dom(bool reversed)
bool dom_info::is_split_point(instruction_ref ins) bool dom_info::is_split_point(instruction_ref ins)
{ {
std::set<int> stream_set; int stream = -1;
for(auto&& arg : ins->outputs()) for(auto&& arg : ins->outputs())
{ {
int arg_stream = arg->get_stream(); int arg_stream = arg->get_stream();
if(arg_stream >= 0) if(arg_stream < 0)
stream_set.insert(arg_stream); continue;
if((stream >= 0) && (arg_stream != stream))
return true;
stream = arg_stream;
} }
return (stream_set.size() > 1); return false;
} }
// Identify merge points. A merge point has more than one // Identify merge points. A merge point has more than one
// inputs that are executed in different streams. // inputs that are executed in different streams.
bool dom_info::is_merge_point(instruction_ref ins) bool dom_info::is_merge_point(instruction_ref ins)
{ {
std::set<int> stream_set; int stream = -1;
for(auto&& arg : ins->inputs()) for(auto&& arg : ins->inputs())
{ {
int arg_stream = arg->get_stream(); int arg_stream = arg->get_stream();
if(arg_stream >= 0) if(arg_stream < 0)
stream_set.insert(arg_stream); continue;
if((stream >= 0) && (arg_stream != stream))
return true;
stream = arg_stream;
} }
return (stream_set.size() > 1); return false;
} }
// Propagate split points through the graph and identify concurrent instructions. // Propagate split points through the graph and identify concurrent instructions.
......
...@@ -9,7 +9,7 @@ namespace cpu { ...@@ -9,7 +9,7 @@ namespace cpu {
struct context struct context
{ {
void finish() {} void finish() const {}
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -148,7 +148,7 @@ struct context ...@@ -148,7 +148,7 @@ struct context
{ {
context(std::size_t n = 0) : current_device(std::make_shared<hip_device>(n)) {} context(std::size_t n = 0) : current_device(std::make_shared<hip_device>(n)) {}
hip_device& get_current_device() hip_device& get_current_device() const
{ {
assert(current_device != nullptr); assert(current_device != nullptr);
return *current_device; return *current_device;
...@@ -165,7 +165,7 @@ struct context ...@@ -165,7 +165,7 @@ struct context
void wait_event(int event) { get_current_device().wait_event(event); } void wait_event(int event) { get_current_device().wait_event(event); }
std::vector<argument> literals{}; std::vector<argument> literals{};
void finish() void finish() const
{ {
get_current_device().stream_sync(); get_current_device().stream_sync();
gpu_sync(); gpu_sync();
......
...@@ -20,14 +20,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,14 +20,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct context struct context
{ {
/// Wait for any tasks in the context to complete /// Wait for any tasks in the context to complete
void finish(); void finish() const;
}; };
#else #else
<% <%
interface('context', interface('context',
virtual('finish', returns='void'), virtual('finish', returns='void', const=True),
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment