Unverified Commit e7471141 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Update passes to use offload_copy based on root module (#1875)

Needed to run multi-targeted programs, where "main" is not the only root module: there can be many root modules besides "main".
parent edc4bf53
......@@ -30,7 +30,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct module_pass_manager;
/**
* Replace `allocate` instructions with target allocations or output parameters.
......@@ -40,7 +40,7 @@ struct replace_allocate
allocation_model model;
bool offload_copy = false;
std::string name() const { return "replace_allocate"; }
void apply(module& m) const;
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager
{
module* mod = nullptr;
module* root_mod = nullptr;
tracer* t = nullptr;
module* common_parent = nullptr;
program* prog = nullptr;
// Two-argument constructor: stores the module and tracer pointers. root_mod is
// not set here, so it keeps its in-class default of nullptr.
// NOTE(review): this listing appears to interleave both sides of a diff; this
// form and the three-argument form below would be ambiguous for default
// construction, so presumably only one of them exists in the real file — confirm.
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
// Three-argument constructor: stores the module being processed (pmod), the
// root module (rmod) and the tracer (pt). All three default to nullptr; a null
// root_mod makes get_root_module() fall back to the program's main module.
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
: mod(pmod), root_mod(rmod), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
......@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
// Returns the root module associated with this pass manager: the explicitly
// supplied root_mod when it is set, otherwise the main module of prog.
virtual module* get_root_module() override
{
    if(root_mod == nullptr)
    {
        // No explicit root module was given, so a program must be attached
        // in order to fall back to its main module.
        assert(prog);
        return prog->get_main_module();
    }
    return root_mod;
}
......@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
continue;
if(not visited.insert(mod).second)
continue;
module_pm mpm{mod, &trace};
module_pm mpm{mod, root_mod, &trace};
mpm.prog = &prog;
auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents);
......@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, &trace}.run_pass(p);
module_pm{&mod, &mod, &trace}.run_pass(p);
}
}
......
......@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module();
if(m.name() == "main")
if(m == *root_module)
return;
for(auto ins : iterator_for(m))
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pass_manager.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
......@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
}
void replace_allocate::apply(module& m) const
void replace_allocate::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
auto mod_output_names = create_output_names(m);
bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
for(auto ins : iterator_for(m))
{
auto op = ins->get_operator();
......@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
continue;
auto s = ins->get_shape();
if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
{
auto out_param = m.add_parameter(mod_output_names[ins], s);
m.replace_instruction(ins, out_param);
......
......@@ -30,7 +30,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct module_pass_manager;
namespace gpu {
......@@ -45,7 +45,7 @@ struct lowering
context* ctx;
bool offload_copy;
std::string name() const { return "gpu::lowering"; }
void apply(module& m) const;
void apply(module_pass_manager& mpm) const;
};
} // namespace gpu
......
......@@ -22,12 +22,19 @@
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
......@@ -35,17 +42,12 @@
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
#include <utility>
#include <functional>
#include <algorithm>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -53,8 +55,9 @@ namespace gpu {
struct miopen_apply
{
module* mod = nullptr;
const lowering* pass = nullptr;
module* mod = nullptr;
module_pass_manager* mpm = nullptr;
const lowering* pass = nullptr;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{};
bool offload_copy = false;
......@@ -83,8 +86,7 @@ struct miopen_apply
auto& ctx = get_context();
int8_x4_format = get_int8_x4_format(ctx);
compute_fp32 = get_compute_fp32_flag();
// TODO: Set Offload copy based on root modules' compile options
offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous");
......@@ -376,7 +378,10 @@ struct miopen_apply
}
};
// Module-based entry point: builds a miopen_apply from the module and this
// pass, then runs it.
// NOTE(review): this listing appears to interleave both sides of a diff;
// presumably this overload is the one superseded by the module_pass_manager
// form below — confirm against the real file.
void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
// Pass-manager entry point: builds a miopen_apply from the manager's current
// module, the manager itself and this pass, then runs it.
void lowering::apply(module_pass_manager& mpm) const
{
// The pass manager is handed over as well so miopen_apply can query
// get_root_module() when deciding whether offload_copy applies to this module.
miopen_apply{&mpm.get_module(), &mpm, this}.apply();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment