Commit 76f68df4 authored by wsttiger

Merged from master

parents dc0c4810 8ae3ffea
......@@ -26,6 +26,20 @@ struct hip_allocate
}
};
struct hip_write
{
std::string name() const { return "hip::write"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.front();
}
argument compute(context&, shape, std::vector<argument> args) const
{
return to_gpu(args.front());
}
};
} // namespace miopen
} // namespace migraph
......
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_LOWERING_HPP
#include <migraph/program.hpp>
namespace migraph {
namespace miopen {
struct lowering
{
std::string name() const { return "miopen::lowering"; }
void apply(program& p) const;
};
} // namespace miopen
} // namespace migraph
#endif
......@@ -2,7 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <miopen/miopen.h>
namespace migraph {
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraph/manage_ptr.hpp>
#include <migraph/operators.hpp>
#include <rocblas.h>
namespace migraph {
namespace miopen {
using rocblas_handle_ptr = MIGRAPH_MANAGE_PTR(rocblas_handle, rocblas_destroy_handle);
rocblas_handle_ptr create_rocblas_handle_ptr();
} // namespace miopen
} // namespace migraph
#endif
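Note: rocblas_handle_ptr gives the rocBLAS handle the same RAII treatment miopen_handle already gets for MIOpen. Assuming MIGRAPH_MANAGE_PTR behaves like a unique_ptr whose deleter calls the supplied destroy function (an assumption about the macro in manage_ptr.hpp, not its literal expansion), the idea in plain, self-contained C++ looks roughly like this:

#include <memory>
#include <type_traits>

// Illustrative stand-ins for the rocBLAS C handle API; the real declarations
// come from <rocblas.h>. They exist only to keep this sketch compilable.
struct example_handle_impl;
using example_handle = example_handle_impl*;
inline void example_destroy_handle(example_handle) {}

// Assumed effect of MIGRAPH_MANAGE_PTR(handle_type, destroy_fn): a unique_ptr
// whose deleter calls the destroy function when the handle goes out of scope.
struct example_handle_deleter
{
    void operator()(example_handle h) const { example_destroy_handle(h); }
};
using example_handle_ptr =
    std::unique_ptr<std::remove_pointer_t<example_handle>, example_handle_deleter>;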
......@@ -6,11 +6,11 @@
namespace migraph {
namespace miopen {
struct miopen_target
struct target
{
std::string name() const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const;
std::vector<pass> get_passes(migraph::context& ctx) const;
migraph::context get_context() const;
};
} // namespace miopen
......
#ifndef MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#define MIGRAPH_GUARD_RTGLIB_MIOPEN_WRITE_LITERALS_HPP
#include <migraph/program.hpp>
namespace migraph {
namespace miopen {
struct write_literals
{
std::string name() const { return "miopen::write_literals"; }
void apply(program& p) const;
};
} // namespace miopen
} // namespace migraph
#endif
#include <migraph/miopen/miopen_target.hpp>
#include <rocblas.h>
#include <migraph/miopen/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
......@@ -7,15 +8,13 @@
#include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/miopen/kernels.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace migraph {
namespace miopen {
struct miopen_context
{
shared<miopen_handle> handle;
};
struct miopen_convolution
{
convolution op;
......@@ -27,9 +26,8 @@ struct miopen_convolution
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto w_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
......@@ -79,9 +77,8 @@ struct miopen_pooling
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(1)});
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
......@@ -112,7 +109,7 @@ struct miopen_add
return inputs.at(0);
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
if(args[1].get_shape().broadcasted())
{
......@@ -129,7 +126,6 @@ struct miopen_add
}
else
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
......@@ -159,18 +155,31 @@ struct miopen_gemm
check_shapes{inputs, *this}.has(3);
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, from_gpu(args[0]), from_gpu(args[1]))(
[&](auto output, auto input1, auto input2) {
dfor(input1.get_shape().lens()[0],
input2.get_shape().lens()[1],
input2.get_shape().lens()[0])(
[&](auto i, auto j, auto k) { output(i, j) += input1(i, k) * input2(k, j); });
});
return to_gpu(result);
float alpha = 1.0f;
float beta = 0.0f;
rocblas_int lda = args[0].get_shape().lens()[1];
rocblas_int ldb = args[1].get_shape().lens()[1];
rocblas_int ldc = args[2].get_shape().lens()[1];
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.rbhandle.get(),
rocblas_operation_none,
rocblas_operation_none,
n,
m,
k,
&alpha,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta,
args[2].implicit(),
ldc);
return args[2];
}
};
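Note on the rocblas_sgemm call above: rocBLAS uses the BLAS column-major convention while migraph shapes are row-major, so the operands and the m/n extents are swapped. Computing B^T * A^T in column-major storage leaves C = A * B laid out row-major in args[2]. A minimal standalone CPU check of that identity (standard C++ only, no rocBLAS; the numbers are arbitrary):

#include <cassert>
#include <vector>

// Plain column-major GEMM with the same operand meaning as BLAS sgemm with no
// transposes: C (m x n) = A (m x k) * B (k x n), all buffers column-major.
static void gemm_colmajor(int m, int n, int k,
                          const std::vector<float>& A, int lda,
                          const std::vector<float>& B, int ldb,
                          std::vector<float>& C, int ldc)
{
    for(int j = 0; j < n; j++)
        for(int i = 0; i < m; i++)
        {
            float acc = 0;
            for(int p = 0; p < k; p++)
                acc += A[i + p * lda] * B[p + j * ldb];
            C[i + j * ldc] = acc;
        }
}

int main()
{
    // Row-major A (2x3), B (3x2), and the expected row-major product A*B (2x2).
    std::vector<float> a    = {1, 2, 3, 4, 5, 6};
    std::vector<float> b    = {7, 8, 9, 10, 11, 12};
    std::vector<float> gold = {58, 64, 139, 154};
    int m = 2, n = 2, k = 3;
    std::vector<float> c(m * n, 0);
    // Same trick as the rocblas_sgemm call: pass B before A, swap m and n, and
    // use each row-major row length as the "column-major" leading dimension.
    // The column-major result B^T * A^T = (A*B)^T then reads back as A*B in
    // row-major order.
    gemm_colmajor(n, m, k, b, n, a, k, c, n);
    for(int i = 0; i < m * n; i++)
        assert(c[i] == gold[i]);
    return 0;
}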
......@@ -216,9 +225,8 @@ struct miopen_relu
return inputs.at(1);
}
argument compute(context& gctx, shape output_shape, std::vector<argument> args) const
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{
auto& ctx = any_cast<miopen_context>(gctx);
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
......@@ -241,7 +249,7 @@ struct miopen_apply
void apply()
{
prog->insert_instruction(prog->begin(), check_context<miopen_context>{});
prog->insert_instruction(prog->begin(), check_context<context>{});
for(auto it = prog->begin(); it != prog->end(); it++)
{
if(it->op.name() == "convolution")
......@@ -354,21 +362,7 @@ struct miopen_apply
}
};
struct miopen_pass
{
std::string name() const { return "miopen::pass"; }
void apply(program& p) const { miopen_apply{&p}.apply(); }
};
std::vector<pass> miopen_target::get_passes(context&) const { return {miopen_pass{}}; }
std::string miopen_target::name() const { return "miopen"; }
context miopen_target::get_context() const
{
return miopen_context{share(make_obj<miopen_handle>(&miopenCreate))};
}
void lowering::apply(program& p) const { miopen_apply{&p}.apply(); }
} // namespace miopen
......
#include <migraph/miopen/rocblas.hpp>
namespace migraph {
namespace miopen {
rocblas_handle_ptr create_rocblas_handle_ptr()
{
rocblas_handle handle;
rocblas_create_handle(&handle);
return rocblas_handle_ptr{handle};
}
} // namespace miopen
} // namespace migraph
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/lowering.hpp>
#include <migraph/miopen/write_literals.hpp>
#include <migraph/miopen/context.hpp>
namespace migraph {
namespace miopen {
std::vector<pass> target::get_passes(migraph::context&) const
{
return {lowering{}, write_literals{}};
}
std::string target::name() const { return "miopen"; }
migraph::context target::get_context() const
{
return context{share(make_obj<miopen_handle>(&miopenCreate)),
share(create_rocblas_handle_ptr())};
}
} // namespace miopen
} // namespace migraph
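With the rename from miopen_target to target, a GPU compile now runs lowering{} followed by write_literals{}, against a context that owns both the MIOpen and rocBLAS handles. Usage follows the updated gpu test further down; a minimal sketch, modeled on the test_literals case added in this commit (the convolution/relu program is only an example):

#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/miopen/target.hpp>

int main()
{
    migraph::program p;
    auto input = p.add_literal(
        generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
    auto weights = p.add_literal(
        generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
    auto conv = p.add_instruction(migraph::convolution{}, input, weights);
    p.add_instruction(migraph::activation{"relu"}, conv);
    // Compiling with the miopen target runs lowering{} and write_literals{};
    // the attached context carries the MIOpen and rocBLAS handles created in
    // target::get_context above.
    p.compile(migraph::miopen::target{});
    p.eval({});
    return 0;
}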
#include <migraph/miopen/write_literals.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/instruction.hpp>
namespace migraph {
namespace miopen {
void write_literals::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->op.name() == "@literal")
{
literal l = ins->lit;
auto pre = p.add_literal(l);
p.replace_instruction(ins, hip_write{}, pre);
}
}
}
} // namespace miopen
} // namespace migraph
......@@ -6,6 +6,25 @@
#include "test.hpp"
#include "verify.hpp"
void batch_norm_inference_test()
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {4}};
auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}});
auto gamma = p.add_literal(migraph::literal{s, {1}});
auto beta = p.add_literal(migraph::literal{s, {0}});
auto mean = p.add_literal(migraph::literal{s, {0}});
auto variance = p.add_literal(migraph::literal{s, {1}});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> result_vector(4);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
EXPECT(test::verify_range(result_vector, gold));
}
void exp_test()
{
migraph::program p;
......@@ -252,6 +271,63 @@ void gemm_test()
}
}
void maxpool_test()
{
migraph::program p;
std::vector<float> a = {
-2.1314404, -1.63041711, 1.54562736, 1.04625261, -1.42931843, -0.48703974, 0.4065806,
-0.1524526, 1.30775225, 0.45538983, -0.06631992, -1.75332725, 1.33493888, 0.47327688,
0.36873096, 1.18358743, -0.34640595, 1.22098756, 0.01946825, -0.20238149, 0.43348005,
-0.67991608, -0.83041084, 0.93537551, 0.70241445, -0.5654031, -1.30899191, -0.26735824,
-0.52444768, 1.99097753, 1.86504853, -0.26506025, 0.26236168, 0.43763575, 0.95300823,
-1.02733946, -0.74655169, -0.5374338, -0.28901565, -0.59789604, 0.5310151, 0.99125904,
0.40609556, -1.57175648, 0.22031412, 1.45862222, 0.53217483, 1.39087725, 1.00170159,
-0.87175864, -1.7204628, -1.72008383, -0.38656762, -0.01443311, 1.46645272, -1.39995027,
0.22505587, -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211, 1.18943918,
-0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603, 0.54316711,
0.40899998, -0.27831686, -1.11900508, -0.0881724, 0.35483059, 2.36277103, -0.04765317,
-0.36865309, 0.73814237, 1.47151589, 1.36546791, -0.32649881, -1.0517807, 2.24768877,
0.68883753, 0.58646208, -0.91017133, -0.50462508, -0.4013325, -0.72348958, -0.47368807,
0.35285577, -1.01817429, -0.5152272, 0.60321307, 0.43521205, -0.23733577, 0.66427642,
0.82949388, 0.82443929, 0.71550399, 0.34561086, 0.68570769, -0.40718508, -1.20350206,
0.15793853, -2.31013632, -0.07934658, -0.09348056, 0.36576006, 2.46601582, 0.11090943,
0.9144392, 0.56759721, -0.22112127, -0.21955389, 0.72474903, -1.28448462, 1.53285873,
0.37437943, 0.31409341, 1.95433736, 0.91620457, 0.86205518, 1.24365854, 0.19248386,
0.22526583, 0.13462132, -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345,
1.31293464, -1.86041689, 1.06763375, -0.26541466, 1.4545635, 1.11430049, -0.66491818,
0.87101674, 0.67768967, -1.02062869, -1.05031872, -2.2764678, -2.0200038, 0.37592548,
-0.26701379, -0.83388507, 0.19403623, 1.00968623, 0.11020003, 1.16736257, -1.1160326,
0.47346735, 0.6126079, -0.19135755, 1.33624589, -0.29802522, -0.57873946, -1.06555879,
-0.20686582, 1.36892557, -0.19937795, 0.8649236, -1.40126073, 1.53441942, 0.34682792,
-1.31724346, -1.32898355, 2.40126371, 0.07845283, 1.35732043, -0.63678312, 0.39429256,
-1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341,
1.20071781, -1.64647579, -0.7133292, 0.88494766, 0.52119428, -2.77387547, 2.07681108,
-0.90133125, 0.2847338, 0.6174528, -0.20616426, -0.64263535, -1.08496261, 0.54275119,
-0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746,
-0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223,
-0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682};
std::vector<float> c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraph::shape a_shape{migraph::shape::float_type, {2, 3, 6, 6}};
auto al = p.add_literal(migraph::literal{a_shape, a});
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::cout << result.get_shape() << std::endl;
std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6;
for(int i = 0; i < results_vector.size(); i++)
{
// std::cout << results_vector[i] << " " << c[i] << std::endl;
EXPECT(std::abs(results_vector[i] - c[i]) < tol);
}
}
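For reference, the 36 expected values match the pooled output shape: with a 2x3x6x6 input, no padding, stride 2 and a 3x2 window (reading the pooling initializer above as mode, padding, stride, lengths, and assuming the usual floor-mode output formula), the spatial output is 2x3, i.e. 2*3*2*3 = 36 elements. A standalone check of that arithmetic:

#include <cassert>

// Usual pooling output-extent formula with symmetric padding:
// out = (in + 2*pad - window) / stride + 1, using integer (floor) division.
static int pooled_extent(int in, int pad, int window, int stride)
{
    return (in + 2 * pad - window) / stride + 1;
}

int main()
{
    // Input 2x3x6x6 with padding {0, 0}, stride {2, 2}, window {3, 2} as in
    // the maxpool_test above.
    int out_h = pooled_extent(6, 0, 3, 2); // (6 - 3) / 2 + 1 = 2
    int out_w = pooled_extent(6, 0, 2, 2); // (6 - 2) / 2 + 1 = 3
    assert(out_h == 2 && out_w == 3);
    assert(2 * 3 * out_h * out_w == 36); // matches results_vector(36)
    return 0;
}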
void softmax_test()
{
migraph::program p;
......@@ -564,7 +640,9 @@ int main()
transpose_test();
contiguous_test();
softmax_test();
// maxpool_test();
conv2d_test();
conv2d_padding_test();
conv2d_padding_stride_test();
batch_norm_inference_test();
}
......@@ -3,7 +3,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/miopen/miopen_target.hpp>
#include <migraph/miopen/target.hpp>
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/manage_ptr.hpp>
......@@ -27,7 +27,7 @@ migraph::argument run_gpu()
{
V v;
auto p = v.create_program();
p.compile(migraph::miopen::miopen_target{});
p.compile(migraph::miopen::target{});
auto m = v.create_params();
for(auto&& e : m)
......@@ -49,6 +49,23 @@ void verify_program()
visit_all(cpu_arg, gpu_arg)([](auto cpu, auto gpu) { EXPECT(test::verify_range(cpu, gpu)); });
}
struct test_literals
{
migraph::program create_program() const
{
migraph::program p;
auto input = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto weights = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraph::convolution{}, input, weights);
p.add_instruction(migraph::activation{"relu"}, conv);
return p;
}
migraph::program::parameter_map create_params() const { return {}; }
};
struct test_add
{
migraph::program create_program() const
......
#ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
<%
......
......@@ -9,6 +9,7 @@
#include <migraph/shape.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
namespace migraph {
......@@ -22,11 +23,17 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
} // namespace operation_stream
template <class T>
argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<argument> input)
{
return x.compute(auto_any_cast(ctx), output_shape, input);
}
<%
interface('operation',
virtual('name', returns='std::string', const=True),
virtual('compute_shape', returns='shape', input='std::vector<shape>', const=True),
virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True),
virtual('compute', returns='argument', ctx='context&', output='shape', input='std::vector<argument>', const=True, default='compute_op'),
friend('operator<<', returns='std::ostream &', os='std::ostream &', op='const operation &', using='migraph::operation_stream::operator<<')
)
%>
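The default='compute_op' hook plus auto_any_cast is what lets the miopen operators above declare compute(context&, ...) against their own backend context instead of any_cast-ing a generic one: the type-erased operation forwards to the free function compute_op, and auto_any_cast recovers whatever context type the callee asks for. A self-contained toy of that dispatch shape (the proxy below is an assumption about how auto_any_cast behaves, not its real implementation):

#include <cassert>

// Toy stand-ins: an erased, generic context and one backend-specific context.
struct gpu_context
{
    int device = 0;
};
struct any_context
{
    gpu_context* impl = nullptr;
};

// Toy auto_any_cast: return a proxy that converts to whichever context
// reference the callee's parameter type requests. (Assumed behaviour only;
// the real helper lives in migraph/auto_any_cast.hpp.)
struct context_proxy
{
    any_context& ctx;
    operator any_context&() const { return ctx; }
    operator gpu_context&() const { return *ctx.impl; }
};
inline context_proxy auto_any_cast(any_context& ctx) { return {ctx}; }

// A backend op declares compute with its concrete context type...
struct gpu_op
{
    int compute(gpu_context& ctx, int x) const { return x + ctx.device; }
};

// ...and the erased interface dispatches through the default free function,
// which is the role compute_op plays for the operation interface above.
template <class T>
int compute_op(const T& op, any_context& ctx, int x)
{
    return op.compute(auto_any_cast(ctx), x);
}

int main()
{
    gpu_context g{7};
    any_context a{&g};
    assert(compute_op(gpu_op{}, a, 35) == 42);
    return 0;
}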
......
......@@ -213,16 +213,21 @@ def internal_name(name):
else:
return name
def generate_call(m, friend):
def generate_call(m, friend, indirect):
if m['name'].startswith('operator'):
op = m['name'][8:]
args = m['args']
if ',' in args:
return args.replace(',', op)
else:
return string.Template('${op}${arga}').substitute(op=op, args=args)
return string.Template('${op}${args}').substitute(op=op, args=args)
if friend:
return string.Template('${name}(${args})').substitute(m)
if indirect:
if m['args']:
return string.Template('${default}(private_detail_te_value, ${args})').substitute(m)
else:
return string.Template('${default}(private_detail_te_value)').substitute(m)
return string.Template('private_detail_te_value.${name}(${args})').substitute(m)
def convert_member(d, struct_name):
......@@ -242,9 +247,12 @@ def convert_member(d, struct_name):
member_params = []
skip = False
friend = False
indirect = False
if 'friend' in d[name]:
friend = True
skip = True
if 'default' in d[name]:
indirect = True
for x in d[name]:
t = d[name][x]
if x == 'return':
......@@ -254,8 +262,12 @@ def convert_member(d, struct_name):
member['member_const'] = 'const'
elif x == 'friend':
member['friend'] = 'friend'
elif x == 'default':
member['default'] = t
elif x == 'using':
member['using'] = 'using {};'.format(d[name]['using'])
elif x.startswith('__') and x.endswith('__'):
continue
else:
use_member = not(skip and struct_name == trim_type_name(t))
arg_name = x
......@@ -278,7 +290,7 @@ def convert_member(d, struct_name):
member['params'] = ','.join(params)
member['params'] = ','.join(params)
member['member_params'] = ','.join(member_params)
member['call'] = generate_call(member, friend)
member['call'] = generate_call(member, friend, indirect)
return member
return None
......