Commit 57444235 authored by Khalique

fix merge conflict

parents a0ea12f6 d8bf45cf
@@ -4,6 +4,12 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "The binary and source directory cannot be the same")
endif()
# This has to be initialized before the project() command appears
# Set the default CMAKE_BUILD_TYPE to Release unless the user specifies one with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE.
if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE )
set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." )
endif()
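# For example, running "cmake -DCMAKE_BUILD_TYPE=Debug .." overrides this default.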
project(migraphlib)
find_package(ROCM REQUIRED)
......
add_library(migraph
auto_contiguous.cpp
common_subexpression_elimination.cpp
constant_propagate.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
......
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/functional.hpp>
#include <unordered_set>
namespace migraph {
template <class Range>
void cse_range(program& p, Range&& r)
{
std::unordered_multimap<std::string, instruction_ref> instructions;
for(auto ins : r)
{
// Skip dead instructions
if(ins->outputs().empty())
continue;
// Find instruction with the same name
auto found_instructions = range(instructions.equal_range(ins->name()));
for(const auto& pp : found_instructions)
{
auto eq = pp.second;
if(*eq != *ins)
continue;
p.replace_instruction(ins, eq);
cse_range(p, eq->outputs());
}
instructions.emplace(ins->name(), ins);
}
}
void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); }
} // namespace migraph
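A quick usage sketch for this pass — a minimal example, not part of the commit, assuming the program/add_literal/add_instruction API used elsewhere in this tree and the matching dead_code_elimination header: two structurally identical adds collapse into one, and dead code elimination can then drop the orphan.

#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/dead_code_elimination.hpp>

int main()
{
    migraph::program p;
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    // Same operator name, same arguments: candidates for CSE
    auto sum1 = p.add_instruction(migraph::op::add{}, one, two);
    auto sum2 = p.add_instruction(migraph::op::add{}, one, two);
    p.add_instruction(migraph::op::mul{}, sum1, sum2);
    migraph::common_subexpression_elimination{}.apply(p); // rewires sum2's users to sum1
    migraph::dead_code_elimination{}.apply(p);            // removes the now-dead sum2
}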
@@ -4,6 +4,7 @@
#include <migraph/context.hpp>
#include <migraph/errors.hpp>
#include <migraph/argument.hpp>
#include <migraph/reflect.hpp>
namespace migraph {
@@ -22,6 +23,13 @@ struct literal
struct outline
{
shape s;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(const std::vector<shape>&) const { return s; } shape compute_shape(const std::vector<shape>&) const { return s; }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
...@@ -33,6 +41,13 @@ struct outline ...@@ -33,6 +41,13 @@ struct outline
struct param struct param
{ {
std::string parameter; std::string parameter;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.parameter, "parameter"));
}
std::string name() const { return "@param"; } std::string name() const { return "@param"; }
shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); } shape compute_shape(const std::vector<shape>&) const { MIGRAPH_THROW("builtin"); }
argument compute(context&, const shape&, const std::vector<argument>&) const argument compute(context&, const shape&, const std::vector<argument>&) const
......
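The reflect() hooks added to outline and param let generic code visit an operator's fields together with their names. A self-contained toy sketch of the pattern — pack() below is a stand-in for migraph's helper, and my_param is a made-up struct, not the real builtin:

#include <functional>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>

// Toy stand-in for migraph's pack(): gather the visited fields into a tuple.
template <class... Ts>
auto pack(Ts&&... xs)
{
    return std::make_tuple(std::forward<Ts>(xs)...);
}

struct my_param
{
    std::string parameter;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.parameter, "parameter"));
    }
};

int main()
{
    my_param p{"x"};
    // The visitor sees each field with its name; here we just print name = value.
    my_param::reflect(p, [](auto& field, const char* name) {
        std::cout << name << " = " << field << "\n";
        return std::ref(field);
    });
}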
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct common_subexpression_elimination
{
std::string name() const { return "common_subexpression_elimination"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
@@ -5,6 +5,7 @@
#include <migraph/shape.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
#include <utility>
@@ -43,6 +44,10 @@ struct instruction
const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y);
friend bool operator!=(const instruction& x, const instruction& y);
friend bool operator==(instruction_ref ref, const instruction& i);
friend bool operator!=(const instruction& i, instruction_ref ref);
@@ -52,7 +57,10 @@ struct instruction
void add_output(instruction_ref ins);
template <class T>
void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
static void backreference(instruction_ref ref);
......
@@ -314,6 +314,57 @@ struct contiguous
}
};
struct concat
{
std::size_t axis = 0;
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis] += arg.get_shape().lens()[axis];
}
return offsets;
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
{
MIGRAPH_THROW("Number of input tensors should exceed 0");
}
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{
if(l != axis)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
}))
{
MIGRAPH_THROW("Non-axis dimensions should match");
}
}
}
std::size_t new_dim_axis = 0;
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return {type, new_lens};
}
};
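To make the shape and offset math concrete — a sketch under assumed test-style usage of the operator above: concatenating {2,3} and {2,5} along axis 1 gives {2,8}, and each input's data starts at the flat index of its first element in the output.

#include <migraph/operators.hpp>
#include <iostream>

int main()
{
    migraph::op::concat concat{1};
    migraph::shape a{migraph::shape::float_type, {2, 3}};
    migraph::shape b{migraph::shape::float_type, {2, 5}};
    // Non-axis dimensions must match; the axis dimensions add: {2,3} + {2,5} -> {2,8}
    std::cout << concat.compute_shape({a, b}) << "\n";
    // With standard strides {8,1}, compute_offsets yields {0, 3}: input 0 starts
    // at out.index({0,0}) == 0 and input 1 at out.index({0,3}) == 3.
}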
struct slice
{
std::vector<int64_t> axes;
@@ -531,7 +582,7 @@ struct reshape
}
};
struct dot
{
float alpha = 1.0;
float beta = 0.0;
@@ -542,7 +593,7 @@ struct gemm
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type();
......
@@ -95,6 +95,10 @@ struct program
void perf_report(std::ostream& os, std::size_t n, parameter_map params) const;
void debug_print();
void debug_print(instruction_ref ins);
void debug_print(const std::vector<instruction_ref>& inss);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
......
@@ -92,6 +92,12 @@ iterator_range<Iterator> range(Iterator start, Iterator last)
return {start, last};
}
template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{
return {p.first, p.second};
}
} // namespace migraph
#endif
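This overload adapts the iterator pair returned by equal_range-style APIs to range-based for, which is how cse_range consumes it above. A minimal sketch, assuming only this header and the standard library:

#include <migraph/ranges.hpp>
#include <iostream>
#include <string>
#include <unordered_map>

int main()
{
    std::unordered_multimap<std::string, int> m{{"add", 1}, {"add", 2}, {"mul", 3}};
    // equal_range returns std::pair<iterator, iterator>; range() makes it iterable
    for(const auto& kv : migraph::range(m.equal_range("add")))
        std::cout << kv.first << " -> " << kv.second << "\n";
}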
@@ -94,6 +94,17 @@ const std::vector<instruction_ref>& instruction::inputs() const { return argumen
const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(const instruction& x, const instruction& y)
{
if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
return false;
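// Literal instructions carry no arguments, so their values must match as well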
if(x.name() == "@literal")
return x.lit == y.lit;
return true;
}
bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
@@ -106,12 +117,6 @@ void instruction::add_output(instruction_ref ins)
output.push_back(ins);
}
void instruction::backreference(instruction_ref ref)
{
for(auto&& arg : ref->inputs())
@@ -151,6 +156,7 @@ void instruction::replace(std::vector<instruction_ref> args)
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
assert(std::any_of(arguments.begin(), arguments.end(), [&](auto i) { return i == old; }));
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
}
......
@@ -50,7 +50,7 @@ struct onnx_parser
{
add_generic_op("Add", op::add{});
add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::activation{"relu"});
add_generic_op("Sub", op::sub{});
@@ -67,6 +67,10 @@
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
}
template <class F>
@@ -188,6 +192,52 @@ struct onnx_parser
return prog.add_instruction(op::flatten{axis}, args[0]);
}
instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = parse_value(attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::slice op;
if(contains(attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
{
literal s = parse_value(attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
}
{
literal s = parse_value(attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
@@ -225,11 +275,11 @@ struct onnx_parser
if(args.size() == 3)
{
uint64_t axis = 1;
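// A three-input Gemm carries a bias: broadcast it to the dot output's shape, then add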
auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4);
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
instruction_ref
......
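As a concrete example of the new Slice handling, an ONNX node with starts=[1], ends=[3], axes=[0] becomes an op::slice that keeps rows 1..2 — a sketch, assuming the slice operator's compute_shape behaves as its member list in operators.hpp suggests:

#include <migraph/operators.hpp>
#include <iostream>

int main()
{
    // What parse_slice would build for starts=[1], ends=[3], axes=[0]
    migraph::op::slice slice;
    slice.axes   = {0};
    slice.starts = {1};
    slice.ends   = {3};
    migraph::shape in{migraph::shape::float_type, {4, 5}};
    std::cout << slice.compute_shape({in}) << "\n"; // keeps rows 1..2 -> lens {2, 5}
}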
@@ -23,21 +23,11 @@ struct program_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
@@ -54,7 +44,6 @@ static void print_program(std::ostream& os, const program& p, F annonate)
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
@@ -62,12 +51,36 @@ static void print_program(std::ostream& os, const program& p, F annonate)
}
os << " -> " << ins->get_shape();
}
template <class F>
static void print_program(std::ostream& os, const program& p, F annonate)
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(p))
{
std::string var_name = "@" + std::to_string(count);
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
names.emplace(ins, var_name);
// TODO: Use all_of
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
(void)arg;
}
print_instruction(os, ins, names);
annonate(ins, names);
os << std::endl;
count++;
}
}
@@ -124,7 +137,9 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
{
return rep;
}
// Make a copy of the outputs, which can change when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{
// TODO: Check for possible cycles
if(out != rep)
@@ -135,6 +150,10 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
}
// Replacement should not be dead code unless it's the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
[&](auto i) { return i == rep; }));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
@@ -449,6 +468,25 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
<< ", " << std::round(calculate_overhead_percent) << "%" << std::endl;
}
void program::debug_print() { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins)
{
std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
std::cout << std::endl;
}
});
}
void program::debug_print(const std::vector<instruction_ref>& inss)
{
for(auto ins : inss)
debug_print(ins);
std::cout << std::endl;
}
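Because these overloads take an instruction_ref (or a vector of them) directly and write to stdout, they are convenient to invoke from a debugger: debug_print(ins) prints just the line for that instruction, with its operands named exactly as in the full program dump.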
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
......
@@ -282,10 +282,38 @@ struct cpu_contiguous
}
};
struct cpu_concat
{
op::concat op;
std::string name() const { return "cpu::concat"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
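// View the output through a shape with this input's lengths but the output's
// strides, offset to the input's start; linear indexing through that view
// then scatters each input element to its final position in the output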
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
});
}
return result;
}
};
struct cpu_gemm
{
op::dot op;
std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
@@ -564,12 +592,12 @@ struct cpu_apply
{
apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>(); apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>(); apply_map["sigmoid"] = simple_op<cpu_unary<sigmoid_op>>();
@@ -581,7 +609,6 @@ struct cpu_apply
apply_map["add"] = simple_op<cpu_binary<add_op>>();
apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["softmax"] = simple_op<softmax2d>();
......
@@ -15,6 +15,7 @@ add_library(migraph_device
device/add_relu.cpp
device/contiguous.cpp
device/mul.cpp
device/concat.cpp
)
rocm_clang_tidy_check(migraph_device)
target_link_libraries(migraph_device migraph hip::device)
@@ -32,6 +33,7 @@ add_library(migraph_gpu
convolution.cpp
softmax.cpp
contiguous.cpp
concat.cpp
relu.cpp
leaky_relu.cpp
add.cpp
......
@@ -14,9 +14,9 @@ shape hip_add::compute_shape(const std::vector<shape>& inputs) const
return inputs.at(0);
}
argument hip_add::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
@@ -34,7 +34,7 @@ argument miopen_add::compute(context& ctx,
auto a_desc = make_tensor(args[0].get_shape());
auto b_desc = make_tensor(args[1].get_shape());
auto c_desc = make_tensor(output_shape);
miopenOpTensor(ctx.get_stream().get_miopen(),
miopenTensorOpAdd,
&alpha,
a_desc.get(),
......
@@ -23,7 +23,7 @@ argument miopen_batch_norm_inference::compute(context& ctx,
float alpha = 1.0, beta = 0.0f;
miopenBatchNormalizationForwardInference(ctx.get_stream().get_miopen(),
miopenBatchNormMode_t(op.bn_mode),
&alpha,
&beta,
......
#include <migraph/gpu/concat.hpp>
#include <migraph/operators.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/device/concat.hpp>
#include <utility>
namespace migraph {
namespace gpu {
shape hip_concat::compute_shape(std::vector<shape> inputs) const
{
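// The last input is the preallocated output buffer that GPU ops receive as a
// trailing argument, so drop its shape before delegating to the operator's check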
inputs.pop_back();
return op.compute_shape(inputs);
}
argument hip_concat::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
std::vector<std::size_t> offsets = op.compute_offsets(output_shape, args);
return device::concat(ctx.get_stream().get(), output_shape, args, offsets);
}
} // namespace gpu
} // namespace migraph
@@ -12,13 +12,14 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument miopen_contiguous::compute(context& ctx,
shape output_shape,
const std::vector<argument>& args) const
{
assert(output_shape == args[1].get_shape());
assert(output_shape.standard());
(void)output_shape;
device::contiguous(ctx.get_stream().get(), args.at(1), args.at(0));
return args.at(1);
}
......
@@ -21,7 +21,7 @@ argument miopen_convolution::compute(context& ctx,
auto y_desc = make_tensor(output_shape);
float alpha = 1, beta = 0;
miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
@@ -47,18 +47,22 @@ shape miopen_convolution::compile(context& ctx,
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
x.implicit(),
w_desc.get(),
......
@@ -5,14 +5,18 @@ namespace migraph {
namespace gpu {
namespace device {
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; });
}
void add(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{ {
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
}
} // namespace device
......
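Threading the hipStream_t (and the stream's MIOpen handle) through every launch, as in the hunks above, ties all device work to the context's stream rather than the default stream, which is what makes per-context ordering and overlap of kernels possible.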