Unverified Commit e7471141 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Update passes to use offload_copy based on root module (#1875)

Needed to run multi-targeted program where "main" isn't the only root module. There could be many root modules other than main.
parent edc4bf53
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
/** /**
* Replace `allocate` instructions with target allocations or output parameters. * Replace `allocate` instructions with target allocations or output parameters.
...@@ -40,7 +40,7 @@ struct replace_allocate ...@@ -40,7 +40,7 @@ struct replace_allocate
allocation_model model; allocation_model model;
bool offload_copy = false; bool offload_copy = false;
std::string name() const { return "replace_allocate"; } std::string name() const { return "replace_allocate"; }
void apply(module& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace) ...@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager struct module_pm : module_pass_manager
{ {
module* mod = nullptr; module* mod = nullptr;
module* root_mod = nullptr;
tracer* t = nullptr; tracer* t = nullptr;
module* common_parent = nullptr; module* common_parent = nullptr;
program* prog = nullptr; program* prog = nullptr;
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {} module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
: mod(pmod), root_mod(rmod), t(pt)
{
}
template <class... Ts> template <class... Ts>
void trace(Ts&&... xs) const void trace(Ts&&... xs) const
{ {
...@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager ...@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
virtual module* get_root_module() override virtual module* get_root_module() override
{ {
if(root_mod != nullptr)
return root_mod;
assert(prog); assert(prog);
return prog->get_main_module(); return prog->get_main_module();
} }
...@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas ...@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
continue; continue;
if(not visited.insert(mod).second) if(not visited.insert(mod).second)
continue; continue;
module_pm mpm{mod, &trace}; module_pm mpm{mod, root_mod, &trace};
mpm.prog = &prog; mpm.prog = &prog;
auto parents = range(tree.equal_range(mod)); auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents); auto nparents = distance(parents);
...@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) ...@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout}; trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
module_pm{&mod, &trace}.run_pass(p); module_pm{&mod, &mod, &trace}.run_pass(p);
} }
} }
......
...@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const ...@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{ {
module& m = mpm.get_module(); module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module(); module_ref root_module = mpm.get_root_module();
if(m.name() == "main") if(m == *root_module)
return; return;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/pass_manager.hpp>
#include <migraphx/replace_allocate.hpp> #include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
...@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio ...@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args); mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
} }
void replace_allocate::apply(module& m) const void replace_allocate::apply(module_pass_manager& mpm) const
{ {
module& m = mpm.get_module();
auto mod_output_names = create_output_names(m); auto mod_output_names = create_output_names(m);
bool main_offload_copy = m.name() == "main" ? this->offload_copy : false; bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
auto op = ins->get_operator(); auto op = ins->get_operator();
...@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const ...@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
continue; continue;
auto s = ins->get_shape(); auto s = ins->get_shape();
if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins)) if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
{ {
auto out_param = m.add_parameter(mod_output_names[ins], s); auto out_param = m.add_parameter(mod_output_names[ins], s);
m.replace_instruction(ins, out_param); m.replace_instruction(ins, out_param);
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
namespace gpu { namespace gpu {
...@@ -45,7 +45,7 @@ struct lowering ...@@ -45,7 +45,7 @@ struct lowering
context* ctx; context* ctx;
bool offload_copy; bool offload_copy;
std::string name() const { return "gpu::lowering"; } std::string name() const { return "gpu::lowering"; }
void apply(module& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -22,12 +22,19 @@ ...@@ -22,12 +22,19 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator> #include <iterator>
#include <migraphx/gpu/lowering.hpp> #include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
...@@ -35,17 +42,12 @@ ...@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -53,8 +55,9 @@ namespace gpu { ...@@ -53,8 +55,9 @@ namespace gpu {
struct miopen_apply struct miopen_apply
{ {
module* mod = nullptr; module* mod = nullptr;
const lowering* pass = nullptr; module_pass_manager* mpm = nullptr;
const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{}; instruction_ref last{};
bool offload_copy = false; bool offload_copy = false;
...@@ -83,8 +86,7 @@ struct miopen_apply ...@@ -83,8 +86,7 @@ struct miopen_apply
auto& ctx = get_context(); auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx); int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag(); compute_fp32 = get_compute_fp32_flag();
// TODO: Set Offload copy based on root modules' compile options offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
...@@ -376,7 +378,10 @@ struct miopen_apply ...@@ -376,7 +378,10 @@ struct miopen_apply
} }
}; };
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); } void lowering::apply(module_pass_manager& mpm) const
{
miopen_apply{&mpm.get_module(), &mpm, this}.apply();
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment