Commit 0db1af37 authored by Paul
Browse files

Merge branch 'jit-improve' into bert-opt2

parents ecfb0b72 6deee23b
......@@ -80,7 +80,7 @@
"outputs": [],
"source": [
"if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16 --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
]
......
......@@ -142,7 +142,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
pm->replace_return(pm->insert_instructions(last, xm, map_ins));
return inputs;
}
......
......@@ -120,9 +120,33 @@ struct module
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
......
......@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -197,6 +197,62 @@ void module::assign(const module& m)
}
}
// Copies every instruction of `instructions` into module `m`, rewriting inputs
// through `map_ins`, and returns the instructions that stand for the outputs of
// the copied sequence.
//
//  * A source instruction that is already a key of `map_ins` is not copied.
//  * "@literal", "@param" and "@outline" are re-created through m.add_literal,
//    m.add_parameter and m.add_outline respectively.
//  * Any other instruction is inserted before `ins` with remapped inputs.
//  * A "@return" is not copied; its remapped inputs become the result and the
//    traversal stops there.
//  * When no "@return" produced a result and the range is not empty, the result
//    is the mapping of the last instruction visited. An empty range yields an
//    empty result.
//
// NOTE(review): a "@return" with no inputs leaves the result empty, so the
// fallback then looks up the (never mapped) "@return" itself with `at` —
// confirm whether that case can occur.
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
                            instruction_ref ins,
                            Range&& instructions,
                            std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
    assert(m.has_instruction(ins) or is_end(ins, m.end()));
    std::vector<instruction_ref> results;
    instruction_ref tail;
    for(instruction_ref src : instructions)
    {
        tail = src;
        // The caller already supplied a replacement for this instruction.
        if(contains(map_ins, src))
            continue;

        const auto op_name = src->name();
        instruction_ref duplicate;
        if(op_name == "@literal")
        {
            auto lit  = src->get_literal();
            duplicate = m.add_literal(lit);
        }
        else if(op_name == "@param")
        {
            auto&& param_name = any_cast<builtin::param>(src->get_operator()).parameter;
            auto param_shape  = src->get_shape();
            duplicate         = m.add_parameter(param_name, param_shape);
        }
        else if(op_name == "@outline")
        {
            auto outline_shape = src->get_shape();
            duplicate          = m.add_outline(outline_shape);
        }
        else
        {
            auto sub_modules = src->module_inputs();
            // Substitute each input that has a mapping; keep the rest untouched.
            std::vector<instruction_ref> remapped;
            remapped.reserve(src->inputs().size());
            for(auto arg : src->inputs())
            {
                if(contains(map_ins, arg))
                    remapped.push_back(map_ins.at(arg));
                else
                    remapped.push_back(arg);
            }
            if(op_name == "@return")
            {
                // The return itself is not copied, only its operands are reported.
                results = remapped;
                break;
            }
            duplicate = m.insert_instruction(ins, src->get_operator(), remapped, sub_modules);
        }
        map_ins[src] = duplicate;
    }
    if(results.empty() and instructions.begin() != instructions.end())
        results = {map_ins.at(tail)};
    return results;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......@@ -335,53 +391,49 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
std::vector<instruction_ref>
module::add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
return this->insert_instructions(this->end(), instructions, std::move(map_ins));
}
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
// Appends copies of the instructions of module `m` to this module.
// Forwards to insert_instructions using this->end() as the insertion point;
// `map_ins` is passed along unchanged and the forwarded call's result is returned.
std::vector<instruction_ref>
module::add_instructions(module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
// Appends copies of the instructions delimited by `start` and `last` to this module.
// Forwards to the (ins, start, last, map_ins) overload of insert_instructions
// using this->end() as the insertion point and returns its result.
std::vector<instruction_ref>
module::add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), start, last, std::move(map_ins));
}
// Inserts copies of the listed `instructions` into this module at position `ins`.
// The work is delegated to insert_generic_instructions, which receives this
// module, the insertion point, the vector itself and the `map_ins` mapping.
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
}
// Inserts copies of the instructions of module `m` into this module at position `ins`.
// The module is traversed through iterator_for(*m) and the work is delegated to
// insert_generic_instructions together with the `map_ins` mapping.
std::vector<instruction_ref> module::insert_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
}
// Inserts copies of the instructions delimited by `start` and `last` into this
// module at position `ins`.
// NOTE(review): the bounds look half-open, i.e. [start, last) — confirm against
// the definition of range().
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
// Wrap the two iterators in a range so it can be walked with iterator_for.
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
}
instruction_ref module::add_literal(literal l)
......
......@@ -27,6 +27,7 @@
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
......@@ -53,29 +54,51 @@ struct index
return blockDim.x; // NOLINT
}
#endif
// Evaluates (n - 1) / stride + 1.
// With integer division and n >= 1 this equals ceil(n / stride), i.e. an upper
// bound on how many steps of size `stride` fit below `n`.
// NOTE(review): assumes N and Stride are integers or integral-constant types
// that support these operators with _c<1> — confirm.
template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride)
{
return (n - _c<1>) / stride + _c<1>;
}
template <class F>
__device__ void global_stride(index_int n, F f) const
template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
const auto stride = nglobal();
for(index_int i = global; i < n; i += stride)
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1)
{
f(i);
if constexpr(stride > n)
{
if(start < n)
f(start);
}
else
{
f(start);
}
}
else
{
for(index_int i = start; i < n; i += stride)
{
f(i);
}
}
}
template <class F>
__device__ void local_stride(index_int n, F f) const
template <class F, class N>
__device__ void global_stride(N n, F f) const
{
const auto stride = nlocal();
for(index_int i = local; i < n; i += stride)
{
f(i);
}
for_stride(global, n, nglobal(), f);
}
// Runs `f` over a strided sequence of indices bounded by `n`: forwards to
// for_stride with this index's `local` member as the start and nlocal() as the stride.
template <class F, class N>
__device__ void local_stride(N n, F f) const
{
for_stride(local, n, nlocal(), f);
}
};
inline __device__ index make_index()
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}
......
......@@ -186,7 +186,8 @@ __device__ auto auto_preload(index idx)
{
return make_transform([=](auto f, auto... xs) {
auto invoke = [=](auto... ys) {
__syncthreads();
if constexpr((Bs or ...))
__syncthreads();
f(ys...);
};
join(invoke, preload_copy<Bs>(idx, xs)...);
......
......@@ -300,6 +300,96 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1);
}
// Inserting a whole module before an existing instruction: the copied "sqrt"
// must land directly in front of the insertion point, and the mapped parameter
// must be reused rather than added a second time.
TEST_CASE(insert_instructions_module)
{
    migraphx::shape int_shape{migraphx::shape::int32_type, {1}};

    migraphx::module m1("m1");
    auto x1       = m1.add_parameter("x1", int_shape);
    auto sqrt_ins = m1.add_instruction(migraphx::make_op("sqrt"), {x1});
    m1.add_instruction(migraphx::make_op("add"), {sqrt_ins, x1});

    migraphx::module m2("m2");
    auto x2 = m2.add_parameter("x2", int_shape);
    m2.add_instruction(migraphx::make_op("sqrt"), {x2});

    // Map m2's parameter onto m1's so no new parameter is created.
    m1.insert_instructions(sqrt_ins, &m2, {{x2, x1}});

    // Number of instructions in m1 carrying the given name.
    const auto count_named = [&](const char* op_name) {
        return std::count_if(
            m1.begin(), m1.end(), [&](auto&& i) { return i.name() == op_name; });
    };

    EXPECT(std::prev(sqrt_ins)->name() == "sqrt");
    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(m1.get_parameter_shapes(), "x1"));
    EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
// Appending a whole module: the "sqrt" of m2 is copied into m1 while m2's
// parameter is replaced by the mapped m1 parameter.
TEST_CASE(add_instructions_module)
{
    migraphx::shape int_shape{migraphx::shape::int32_type, {1}};

    migraphx::module m1("m1");
    auto x1 = m1.add_parameter("x1", int_shape);
    m1.add_instruction(migraphx::make_op("sqrt"), {x1});

    migraphx::module m2("m2");
    auto x2 = m2.add_parameter("x2", int_shape);
    m2.add_instruction(migraphx::make_op("sqrt"), {x2});

    m1.add_instructions(&m2, {{x2, x1}});

    // Number of instructions in m1 carrying the given name.
    const auto count_named = [&](const char* op_name) {
        return std::count_if(
            m1.begin(), m1.end(), [&](auto&& i) { return i.name() == op_name; });
    };

    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(m1.get_parameter_shapes(), "x1"));
    EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
// Appending an iterator range taken from another module: only the "sqrt" is in
// the range, and its input is redirected to m1's parameter through the mapping.
TEST_CASE(add_instructions_range)
{
    migraphx::shape int_shape{migraphx::shape::int32_type, {1}};

    migraphx::module m1("m1");
    auto x1 = m1.add_parameter("x1", int_shape);
    m1.add_instruction(migraphx::make_op("sqrt"), {x1});

    migraphx::module m2("m2");
    auto x2       = m2.add_parameter("x2", int_shape);
    auto sqrt_src = m2.add_instruction(migraphx::make_op("sqrt"), {x2});

    m1.add_instructions(sqrt_src, m2.end(), {{x2, x1}});

    // Number of instructions in m1 carrying the given name.
    const auto count_named = [&](const char* op_name) {
        return std::count_if(
            m1.begin(), m1.end(), [&](auto&& i) { return i.name() == op_name; });
    };

    EXPECT(std::any_of(
        m1.begin(), m1.end(), [&](auto&& i) { return migraphx::contains(i.inputs(), x1); }));
    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(m1.get_parameter_shapes(), "x1"));
    EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
// Appending an explicit list of instructions: the single listed "sqrt" is
// copied and its input is redirected to m1's parameter through the mapping.
TEST_CASE(add_instructions_vector)
{
    migraphx::shape int_shape{migraphx::shape::int32_type, {1}};

    migraphx::module m1("m1");
    auto x1 = m1.add_parameter("x1", int_shape);
    m1.add_instruction(migraphx::make_op("sqrt"), {x1});

    migraphx::module m2("m2");
    auto x2       = m2.add_parameter("x2", int_shape);
    auto sqrt_src = m2.add_instruction(migraphx::make_op("sqrt"), {x2});

    m1.add_instructions({sqrt_src}, {{x2, x1}});

    // Number of instructions in m1 carrying the given name.
    const auto count_named = [&](const char* op_name) {
        return std::count_if(
            m1.begin(), m1.end(), [&](auto&& i) { return i.name() == op_name; });
    };

    EXPECT(std::any_of(
        m1.begin(), m1.end(), [&](auto&& i) { return migraphx::contains(i.inputs(), x1); }));
    EXPECT(count_named("sqrt") == 2);
    EXPECT(count_named("@param") == 1);
    EXPECT(contains(m1.get_parameter_shapes(), "x1"));
    EXPECT(not contains(m1.get_parameter_shapes(), "x2"));
}
struct check_for_pass_op
{
bool* found = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment