"...remote_machine/remoteMachineTrainingService.ts" did not exist on "982b30b553a89762ac6054c8307ee69ed27efd5d"
Commit c83ee9f8 authored by Paul's avatar Paul
Browse files

Add mlir verification

parent e2967e04
......@@ -98,7 +98,7 @@ struct module
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
......
......@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
using const_module_ref = const module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -310,7 +310,7 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
instruction_ref ins, const_module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
......
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <test.hpp>
using migraphx::trim;
......@@ -28,6 +33,67 @@ std::string encode(std::string s)
return migraphx::trim(ss.str());
}
migraphx::program create_program_from_mlir(const migraphx::module& mmlir)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto names = mmlir.get_parameter_names();
std::vector<migraphx::instruction_ref> inputs;
std::transform(names.begin(), names.end(), std::back_inserter(inputs), [&](const auto& name) {
return mm->add_parameter(name, mmlir.get_parameter_shape(name));
});
inputs.push_back(mm->add_parameter("output", mmlir.get_output_shapes().front()));
migraphx::gpu::insert_mlir(*mm, mm->end(), mmlir, inputs);
return p;
}
migraphx::parameter_map generate_params(const migraphx::program& p)
{
migraphx::parameter_map m;
std::size_t i = 0;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraphx::generate_argument(x.second, i++);
}
return m;
}
migraphx::argument run_gpu(migraphx::program p, const migraphx::parameter_map& inputs)
{
migraphx::gpu::target t;
p.compile(t);
migraphx::parameter_map m;
for(auto&& input : inputs)
{
m[input.first] = t.copy_to(input.second);
}
for(auto&& x : p.get_parameter_shapes())
{
if(m.count(x.first) == 0)
{
m[x.first] = t.allocate(x.second);
}
}
return t.copy_from(p.eval(m).front());
}
migraphx::argument run_ref(migraphx::program p, const migraphx::parameter_map& inputs)
{
p.compile(migraphx::ref::target{});
return p.eval(inputs).front();
}
bool verify_mlir(const migraphx::module& mmlir)
{
migraphx::program ref;
ref.get_main_module()->insert_module_instructions(ref.get_main_module()->end(), &mmlir);
auto inputs = generate_params(ref);
auto mlir = create_program_from_mlir(mmlir);
return migraphx::verify_args("mlir", run_ref(ref, inputs), run_gpu(mlir, inputs));
}
TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
......@@ -49,7 +115,7 @@ module {
return;
std::cout << s << std::endl;
EXPECT(encode(s) == encode(mlir_output));
auto op = migraphx::gpu::compile_mlir(m);
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_add_relu)
......@@ -78,6 +144,7 @@ module {
return;
std::cout << s << std::endl;
EXPECT(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment