Commit 17f4ba28 authored by Paul

Merge branch 'jit-vector-reduce' into jit-vector-softmax

parents a8a8d868 c84154b8
......@@ -178,9 +178,12 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
});
return vec_reduce(block_reduce(idx,
op,
init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
});
}
......
......@@ -147,5 +147,19 @@ constexpr auto vec_packed_transform(Ts... xs)
};
}
template <class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
return x;
else
{
vec_type<T> result = x[0];
for(int i = 1; i < vec_size<T>(); i++)
result = op(result, x[i]);
return result;
}
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
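For context on the two hunks above: when the input tensor has been vectorized, the value returned by block_reduce is itself a small vector of per-lane partial results, and vec_reduce folds those lanes into a single scalar with the same operator (scalar inputs pass through unchanged, since vec_size<T>() < 2). The following is a minimal standalone host-C++ sketch of that lane fold, using a Clang/GCC vector-extension type instead of the MIGraphX vec helpers; it only illustrates the behavior and is not part of the repository.

// Standalone sketch of the lane fold performed by vec_reduce.
// Uses a plain Clang/GCC vector-extension type rather than the MIGraphX
// vec_type/vec_size helpers; compiles as ordinary host C++.
#include <cstdio>

using float4 = float __attribute__((vector_size(4 * sizeof(float))));

template <class Vec, class Op>
float fold_lanes(Vec v, int lanes, Op op)
{
    // Same shape as vec_reduce: start from lane 0, combine the rest with op
    float result = v[0];
    for(int i = 1; i < lanes; i++)
        result = op(result, v[i]);
    return result;
}

int main()
{
    // A 4-wide accumulator of partial results, as a vectorized read would produce
    float4 partial = {1.0f, 2.0f, 3.0f, 4.0f};
    // Fold the lanes with the reduction operator (here, addition)
    float total = fold_lanes(partial, 4, [](float a, float b) { return a + b; });
    std::printf("%g\n", total); // prints 10
    return 0;
}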
......@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
__device__ __host__ auto vectorize_tensor(T x)
{
constexpr auto shape = get_shape_c<T>{};
if constexpr(shape.strides[Axis] == 0)
if constexpr(shape.lens[Axis] == 1)
return x;
else if constexpr(shape.strides[Axis] == 0)
return tensor_step<N>(x, _c<Axis>);
else
return as_vec<N>(x, _c<Axis>);
......
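The added branch in vectorize_tensor short-circuits degenerate axes: a length-1 axis cannot be vectorized, so the tensor is returned untouched; a stride-0 (broadcast) axis still goes through tensor_step; only a genuine axis is packed N-wide with as_vec. Below is a small host-side sketch of the same three-way compile-time dispatch, with made-up integer template parameters standing in for the tensor's shape metadata; it is illustrative only.

// Sketch only: hypothetical compile-time dispatch mirroring vectorize_tensor.
// Len and Stride stand in for shape.lens[Axis] and shape.strides[Axis].
#include <cstdio>

template <int Len, int Stride>
const char* vectorize_kind()
{
    if constexpr(Len == 1)
        return "pass through (length-1 axis, nothing to vectorize)";
    else if constexpr(Stride == 0)
        return "tensor_step (broadcast axis, stride 0)";
    else
        return "as_vec (pack the axis into vectors)";
}

int main()
{
    std::puts(vectorize_kind<1, 0>()); // length-1 axis
    std::puts(vectorize_kind<8, 0>()); // broadcast axis
    std::puts(vectorize_kind<8, 1>()); // contiguous axis
    return 0;
}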
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
if(not x_ins->get_shape().standard())
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
}
};
struct find_triaddlayernorm
{
auto matcher() const
{
auto add1 =
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["z1"];
auto y_ins = r.instructions["z2"];
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
}
};
} // namespace
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
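Both matchers gate the rewrite on the size of the reduced (last) dimension: anything over 1024 elements is rejected, and sizes over 256 are only accepted when they are divisible by 4 (so the kernel can use 4-wide vector access). A small sketch of that predicate follows; fits_fused_layernorm is a hypothetical helper name, but the checks mirror the conditions in the apply() functions above.

// Sketch of the size heuristic used by find_layernorm / find_triaddlayernorm.
// fits_fused_layernorm is a hypothetical name introduced for illustration.
#include <cstddef>
#include <cstdio>

bool fits_fused_layernorm(std::size_t relements)
{
    if(relements > 1024)
        return false; // over the hard size cap
    if(relements % 4 != 0 and relements > 256)
        return false; // larger than 256 but not 4-wide vectorizable
    return true;
}

int main()
{
    std::printf("768 -> %d\n", fits_fused_layernorm(768));   // 1: divisible by 4
    std::printf("770 -> %d\n", fits_fused_layernorm(770));   // 0: >256 and not divisible by 4
    std::printf("200 -> %d\n", fits_fused_layernorm(200));   // 1: small enough either way
    std::printf("2048 -> %d\n", fits_fused_layernorm(2048)); // 0: over 1024
    return 0;
}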
......@@ -77,28 +77,28 @@ MIGRAPHX_REGISTER_OP(wait_event)
MIGRAPHX_REGISTER_OP(set_stream)
std::size_t schedule_model::concurrency() const { return streams; }
void schedule_model::sched(module& p, instruction_ref ins, std::size_t n) const
void schedule_model::sched(module& m, instruction_ref ins, std::size_t n) const
{
auto last_stream = std::find_if(std::make_reverse_iterator(ins),
std::make_reverse_iterator(p.begin()),
std::make_reverse_iterator(m.begin()),
[&](auto&& i) { return i.name() == "gpu::set_stream"; });
if(last_stream != std::make_reverse_iterator(p.begin()))
if(last_stream != std::make_reverse_iterator(m.begin()))
{
auto&& op = any_cast<set_stream>(last_stream->get_operator());
// If the same stream was set earlier then skip
if(op.stream == n)
return;
}
p.insert_instruction(ins, set_stream{n});
m.insert_instruction(ins, set_stream{n});
}
void schedule_model::wait(module& p, instruction_ref ins, std::size_t wait_id) const
void schedule_model::wait(module& m, instruction_ref ins, std::size_t wait_id) const
{
p.insert_instruction(ins, wait_event{wait_id});
m.insert_instruction(ins, wait_event{wait_id});
}
void schedule_model::record(module& p, instruction_ref ins, std::size_t wait_id) const
void schedule_model::record(module& m, instruction_ref ins, std::size_t wait_id) const
{
p.insert_instruction(std::next(ins), record_event{wait_id});
m.insert_instruction(std::next(ins), record_event{wait_id});
}
static std::unordered_map<std::string, std::size_t> create_weight_map()
......
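Beyond the parameter rename, schedule_model::sched walks backwards from the insertion point to the most recently inserted gpu::set_stream and only emits a new one when the stream number actually changes. The reverse scan itself is plain standard C++; a minimal sketch over a std::list<int> (standing in for the instruction stream, with negative values as stand-ins for set_stream instructions) is below.

// Minimal sketch of the reverse scan in schedule_model::sched:
// std::find_if over reverse iterators locates the most recent matching
// element before the insertion point.
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <list>

int main()
{
    // Negative entries stand in for hypothetical set_stream instructions.
    std::list<int> prog = {-1, 10, 11, -2, 12};
    auto pos = std::prev(prog.end()); // insertion point: the element "12"

    auto last_stream = std::find_if(std::make_reverse_iterator(pos),
                                    std::make_reverse_iterator(prog.begin()),
                                    [](int i) { return i < 0; });

    if(last_stream != std::make_reverse_iterator(prog.begin()))
        std::printf("most recent set_stream marker: %d\n", *last_stream); // prints -2
    return 0;
}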
......@@ -8,9 +8,9 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void sync_device::apply(module& p) const
void sync_device::apply(module& m) const
{
auto last = std::prev(p.end());
auto last = std::prev(m.end());
if(last->name() == "@return")
{
auto inputs = last->inputs();
......@@ -18,10 +18,10 @@ void sync_device::apply(module& p) const
return (i->name() == "hip::copy_from_gpu");
}))
{
auto sync_in = p.insert_instruction(last, make_op("hip::sync_stream"), inputs);
auto sync_in = m.insert_instruction(last, make_op("hip::sync_stream"), inputs);
if(not inputs.empty())
{
p.replace_instruction(inputs.front(), sync_in);
m.replace_instruction(inputs.front(), sync_in);
}
}
}
......
......@@ -31,6 +31,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
......@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_algebra{},
simplify_reshapes{},
simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
simplify_reshapes{},
propagate_constant{},
......
......@@ -11,25 +11,25 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_COPY_LITERALS)
void write_literals::apply(module& p) const
void write_literals::apply(module& m) const
{
assert(ctx != nullptr);
std::size_t n = 0;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() == "@literal")
{
if(enabled(MIGRAPHX_COPY_LITERALS{}))
{
literal l = ins->get_literal();
auto pre = p.add_literal(l);
auto alloc = p.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
p.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
auto pre = m.add_literal(l);
auto alloc = m.insert_instruction(std::next(pre), hip_allocate{l.get_shape()});
m.replace_instruction(ins, hip_copy_to_gpu{}, pre, alloc);
}
else
{
std::string id = p.name() + ":@literal:" + std::to_string(n);
p.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
std::string id = m.name() + ":@literal:" + std::to_string(n);
m.replace_instruction(ins, hip_copy_literal{ins->get_literal(), id});
n++;
}
}
......
......@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module)
{
migraphx::program p;
......
......@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any2)
......@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any3)
......@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any4)
......@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_either_args_any5)
......@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
}
TEST_CASE(match_all_of1)
......@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape())
.bind("pass");
auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape())
.bind("pass");
auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass});
}
......
......@@ -4,12 +4,20 @@
set -e
#install pip3, rocm-cmake, rocblas and miopen
apt update && apt install -y python3-pip rocm-cmake rocblas miopen-hip openmp-extras
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# Need pip3 and Python headers to build dependencies
apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages
pip3 install setuptools wheel
# install rbuild to build dependencies
pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
PREFIX=/usr/local
REQ_FILE_DIR=""
if [ "$#" -ge 2 ]; then
......