Commit 9164116a authored by Paul's avatar Paul
Browse files

Add a pass to remove allocations on the gpu

parent aa5e156d
......@@ -18,6 +18,7 @@ target_link_libraries(migraph_device migraph hip::device)
target_include_directories(migraph_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
add_library(migraph_gpu
eliminate_allocation.cpp
eliminate_workspace.cpp
hip.cpp
target.cpp
......
#include <migraph/gpu/eliminate_allocation.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/stringutils.hpp>
namespace migraph {
namespace gpu {
void eliminate_allocation::apply(program& p) const
{
std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p))
{
if(ins->op.name() != "hip::allocate")
continue;
allocs.emplace_back(ins, n);
std::size_t size = ins->get_shape().bytes();
n += size + (size % 4);
}
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs)
{
auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(ins, hip_load{s, offset}, mem);
}
}
} // namespace gpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
namespace gpu {
struct eliminate_allocation
{
std::string name() const { return "eliminate_allocation"; }
void apply(program& p) const;
};
} // namespace gpu
} // namespace migraph
#endif
......@@ -28,6 +28,22 @@ struct hip_allocate
}
};
struct hip_load
{
shape s;
std::size_t offset = 0;
std::string name() const { return "hip::load"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(1);
return s;
}
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
return {s, args[0].data() + offset};
}
};
struct hip_write
{
std::string name() const { return "hip::write"; }
......
......@@ -3,6 +3,7 @@
#include <migraph/gpu/write_literals.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/eliminate_workspace.hpp>
#include <migraph/gpu/eliminate_allocation.hpp>
#include <migraph/check_context.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
......@@ -27,6 +28,7 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
eliminate_contiguous{},
dead_code_elimination{},
write_literals{},
eliminate_allocation{},
check_context<context>{},
dead_code_elimination{}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment