Commit 87cd03e0 authored by Shucai Xiao

merge changes from develop branch

parents 46b25c33 36b01ba5
@@ -35,7 +35,7 @@ struct shape
     m(int64_type, int64_t) \
     m(uint32_type, uint32_t) \
     m(uint64_type, uint64_t)
 // clang-format on
 #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
     enum type_t
@@ -188,8 +188,7 @@ struct shape
     {
         switch(t)
         {
-        case tuple_type:
-        {
+        case tuple_type: {
             tv();
             return;
         }
...
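The list being edited above is an X-macro: one master list of (enumerator, C++ type) pairs is expanded with MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES to emit the type_t enumerators, and re-expanded elsewhere to generate visitors. A minimal self-contained sketch of the pattern, using a hypothetical list rather than MIGraphX's real one:

// Sketch of the X-macro idiom (list name and members are illustrative only).
// One master list drives both the enum and its string table, so the two
// cannot drift out of sync.
#include <cstdint>
#include <string>

#define MY_VISIT_TYPES(m)  \
    m(float_type, float)   \
    m(int32_type, int32_t) \
    m(uint64_type, uint64_t)

#define MY_GENERATE_ENUM(x, t) x,
enum type_t
{
    MY_VISIT_TYPES(MY_GENERATE_ENUM) tuple_type
};
#undef MY_GENERATE_ENUM

#define MY_GENERATE_NAME(x, t) \
    case x: return #x;
inline std::string type_name(type_t tt)
{
    switch(tt)
    {
        MY_VISIT_TYPES(MY_GENERATE_NAME)
        case tuple_type: return "tuple_type";
    }
    return "unknown";
}
#undef MY_GENERATE_NAME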
@@ -131,7 +131,8 @@ inline std::string interpolate_string(const std::string& input,
                                       std::string start = "${",
                                       std::string end = "}")
 {
-    return interpolate_string(input,
+    return interpolate_string(
+        input,
         [&](auto start_it, auto last_it) {
             auto key = trim({start_it, last_it});
             auto it  = vars.find(key);
...
@@ -315,8 +315,7 @@ struct value
 {
     switch(this->get_type())
     {
-    case null_type:
-    {
+    case null_type: {
         std::nullptr_t null{};
         if(this->key.empty())
             v(null);
@@ -325,8 +324,7 @@ struct value
         return;
     }
 #define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
-    case vt##_type: \
-    { \
+    case vt##_type: { \
         if(this->key.empty()) \
             v(this->get_##vt()); \
         else \
@@ -346,15 +344,13 @@ struct value
 {
     switch(this->get_type())
     {
-    case null_type:
-    {
+    case null_type: {
         std::nullptr_t null{};
         v(null);
         return;
     }
 #define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
-    case vt##_type: \
-    { \
+    case vt##_type: { \
         v(this->get_##vt()); \
         return; \
     }
...
@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
 {
     switch(o.type)
     {
-    case msgpack::type::NIL:
-    {
+    case msgpack::type::NIL: {
         v = nullptr;
         break;
     }
-    case msgpack::type::BOOLEAN:
-    {
+    case msgpack::type::BOOLEAN: {
         v = o.as<bool>();
         break;
     }
-    case msgpack::type::POSITIVE_INTEGER:
-    {
+    case msgpack::type::POSITIVE_INTEGER: {
         v = o.as<std::uint64_t>();
         break;
     }
-    case msgpack::type::NEGATIVE_INTEGER:
-    {
+    case msgpack::type::NEGATIVE_INTEGER: {
         v = o.as<std::int64_t>();
         break;
     }
     case msgpack::type::FLOAT32:
-    case msgpack::type::FLOAT64:
-    {
+    case msgpack::type::FLOAT64: {
         v = o.as<double>();
         break;
     }
-    case msgpack::type::STR:
-    {
+    case msgpack::type::STR: {
         v = o.as<std::string>();
         break;
     }
-    case msgpack::type::BIN:
-    {
+    case msgpack::type::BIN: {
         v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
         break;
     }
-    case msgpack::type::ARRAY:
-    {
+    case msgpack::type::ARRAY: {
         migraphx::value r = migraphx::value::array{};
         std::for_each(
             o.via.array.ptr,
@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
         v = r;
         break;
     }
-    case msgpack::type::MAP:
-    {
+    case msgpack::type::MAP: {
         migraphx::value r = migraphx::value::object{};
         std::for_each(o.via.map.ptr,
                       o.via.map.ptr + o.via.map.size,
@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
         v = r;
         break;
     }
-    case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported.");
+    case msgpack::type::EXT: {
+        MIGRAPHX_THROW("msgpack EXT type not supported.");
     }
     }
     return o;
...
@@ -382,8 +382,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
     case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
     case onnx::TensorProto::UINT64:
         return create_literal(shape::uint64_type, dims, t.uint64_data());
-    case onnx::TensorProto::FLOAT16:
-    {
+    case onnx::TensorProto::FLOAT16: {
         std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
         std::vector<half> data_half;
         std::transform(data_uint16.begin(),
@@ -453,7 +452,8 @@ shape::type_t get_type(int dtype)
     case 11: return shape::double_type;
     case 12: return shape::uint32_type;
     case 13: return shape::uint64_type;
-    default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+    default: {
+        MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
     }
     }
 }
...
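For context on the FLOAT16 branch above: ONNX serializes half-precision tensors by widening each 16-bit pattern into int32_data, so the parser narrows back to uint16_t and reinterprets the bits as half values (via MIGraphX's half type). A hedged, standalone sketch of what that bit reinterpretation means; MIGraphX does not hand-roll this conversion:

// Decode an IEEE 754 binary16 bit pattern into a float (illustrative only).
#include <cstdint>
#include <cstring>

inline float half_bits_to_float(uint16_t h)
{
    const uint32_t sign = static_cast<uint32_t>(h >> 15) << 31;
    const uint32_t exp  = (h >> 10) & 0x1f;
    const uint32_t frac = h & 0x3ff;
    uint32_t bits;
    if(exp == 0) // zero or subnormal: value is frac * 2^-24
    {
        float f = static_cast<float>(frac) * (1.0f / (1 << 24));
        std::memcpy(&bits, &f, sizeof(bits));
        bits |= sign;
    }
    else if(exp == 31) // infinity or NaN
    {
        bits = sign | 0x7f800000u | (frac << 13);
    }
    else // normal number: rebias the exponent from 15 to 127
    {
        bits = sign | ((exp + 112) << 23) | (frac << 13);
    }
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}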
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
     {"Floor", "floor"},
     {"Gather", "gather"},
     {"Identity", "identity"},
+    {"IsNaN", "isnan"},
     {"LeakyRelu", "leaky_relu"},
     {"Log", "log"},
     {"LRN", "lrn"},
...
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatternd : op_parser<parse_scatternd>
{
std::vector<op_desc> operators() const { return {{"ScatterND"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
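This new parser maps the optional ONNX reduction attribute onto three separate MIGraphX operators (scatternd_none, scatternd_add, scatternd_mul). As a reference point, a hedged sketch of what the "none" variant computes in the simplest 1-D case; the real operator generalizes to multi-dimensional index tuples and slice updates:

// Hedged sketch of ScatterND "none" semantics, 1-D case: the output starts
// as a copy of data, then each update overwrites the addressed element.
#include <cstddef>
#include <vector>

std::vector<float> scatternd_none_1d(std::vector<float> data, // copied: becomes the output
                                     const std::vector<std::size_t>& indices,
                                     const std::vector<float>& updates)
{
    for(std::size_t i = 0; i < indices.size(); ++i)
        data[indices[i]] = updates[i]; // the add variant would use +=, mul *=
    return data;
}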
@@ -809,7 +809,8 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
     std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
         return mod->name();
     });
-    transform_if(m.begin(),
+    transform_if(
+        m.begin(),
         m.end(),
         out,
         [&](auto&& pp) { return not contains(used, pp.first); },
...
@@ -303,15 +303,15 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("name", &migraphx::operation::name);

-    m.def("parse_tf",
+    m.def(
+        "parse_tf",
         [](const std::string& filename,
            bool is_nhwc,
            unsigned int batch_size,
            std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
            std::vector<std::string> output_names) {
             return migraphx::parse_tf(
-                filename,
-                migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
+                filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
         },
         "Parse tf protobuf (default format is nhwc)",
         py::arg("filename"),
@@ -320,7 +320,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
         py::arg("output_names") = std::vector<std::string>());

-    m.def("parse_onnx",
+    m.def(
+        "parse_onnx",
         [](const std::string& filename,
            unsigned int default_dim_value,
            std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
@@ -343,7 +344,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         py::arg("print_program_on_error") = false,
         py::arg("max_loop_iterations") = 10);

-    m.def("parse_onnx_buffer",
+    m.def(
+        "parse_onnx_buffer",
         [](const std::string& onnx_buffer,
            unsigned int default_dim_value,
            std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
@@ -363,7 +365,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         py::arg("skip_unknown_operators") = false,
         py::arg("print_program_on_error") = false);

-    m.def("load",
+    m.def(
+        "load",
         [](const std::string& name, const std::string& format) {
             migraphx::file_options options;
             options.format = format;
@@ -373,7 +376,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         py::arg("filename"),
         py::arg("format") = "msgpack");

-    m.def("save",
+    m.def(
+        "save",
         [](const migraphx::program& p, const std::string& name, const std::string& format) {
             migraphx::file_options options;
             options.format = format;
...
@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
 std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
 {
-    while(reduce_dim(shapes, n) and n < shapes.size())
-    {
-    }
+    while(reduce_dim(shapes, n) and n < shapes.size()) {}
     return n + 1;
 }
...
@@ -335,7 +335,6 @@ struct find_concat_op
     }
     auto y = p.insert_instruction(ins, op, concats);
     return {y};
 };
 std::vector<instruction_ref> args;
...
@@ -316,7 +316,6 @@ struct find_nested_concat
         else
             args.push_back(i);
     }
 })(ins->inputs());
 p.replace_instruction(ins, ins->get_operator(), args);
...
@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
 bool is_vectorizable(const Xs&... xs)
 {
     return all_of({xs...}, [](const auto& s) {
         if(s.standard() and (s.lens().back() % N) == 0)
             return true;
         if(s.broadcasted())
...
@@ -133,6 +133,7 @@ add_library(migraphx_gpu
     compile_hip_code_object.cpp
     compile_pointwise.cpp
     compile_roialign.cpp
+    compile_scatternd.cpp
     concat.cpp
     convert.cpp
     convolution.cpp
...
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., REDUCTION);
});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
hip_compile_options options;
auto out_s = io_shapes.back();
options.local = 1024;
options.global = compute_global(io_shapes.at(1).elements(), options.local);
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = io_shapes;
options.params += " -DREDUCTION=assign_" + reduction + "{}";
return compile_hip_code_object(scatternd_kernel, options);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
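The -DREDUCTION=assign_<reduction>{} flag above bakes the scatter policy into the kernel at compile time as a functor object, so one kernel source serves all three operators with no runtime branch. The real functors live in MIGraphX's kernel headers; the definitions below are assumptions that illustrate the idea:

// Illustrative sketch of compile-time reduction selection; these definitions
// are assumptions, not MIGraphX's actual kernel-header functors.
struct assign_none
{
    template <class T, class U>
    constexpr void operator()(T& x, U y) const { x = y; }
};
struct assign_add
{
    template <class T, class U>
    constexpr void operator()(T& x, U y) const { x += y; }
};
struct assign_mul
{
    template <class T, class U>
    constexpr void operator()(T& x, U y) const { x *= y; }
};

// The kernel body applies whatever functor the compile flag named, e.g.
//   scatternd(indices, updates, output, REDUCTION);
// with REDUCTION expanding to assign_add{} for the "add" reduction.
template <class Assign>
void apply_update(float& slot, float update, Assign assign)
{
    assign(slot, update); // overwrite, accumulate, or multiply per policy
}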
@@ -44,7 +44,8 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Outpu
 template <index_int N, class Op, class T, class Input, class Output>
 __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
 {
-    block_scan<N>(idx,
+    block_scan<N>(
+        idx,
         op,
         init,
         [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
...
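This second block_scan overload is a thin adapter: it wraps the raw element count n into a "for-stride" callable and forwards to the functor-based overload, keeping a single scan implementation. A hypothetical host-side sketch of that count-to-functor idiom:

// Hypothetical sketch: adapt a plain count into a strided-iteration functor.
#include <cstddef>

template <class ForStride, class Visit>
void scan_impl(ForStride fs, Visit visit)
{
    fs(visit); // the functor decides how indices are generated and strided
}

template <class Visit>
void scan_impl(std::size_t n, Visit visit)
{
    // wrap the count in a callable, then forward to the one implementation
    scan_impl([=](auto f) {
        for(std::size_t i = 0; i < n; ++i)
            f(i);
    },
    visit);
}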
@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
 {
     switch(n)
     {
-    case 1:
-    {
+    case 1: {
         f(std::integral_constant<index_int, 1>{});
         break;
     }
-    case 2:
-    {
+    case 2: {
         f(std::integral_constant<index_int, 2>{});
         break;
     }
-    case 3:
-    {
+    case 3: {
         f(std::integral_constant<index_int, 3>{});
         break;
     }
-    case 4:
-    {
+    case 4: {
         f(std::integral_constant<index_int, 4>{});
         break;
     }
-    case 5:
-    {
+    case 5: {
         f(std::integral_constant<index_int, 5>{});
         break;
     }
...
@@ -25,7 +25,8 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
     // fill all output to 0 first
     idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
-    block_scan<block_size>(idx,
+    block_scan<block_size>(
+        idx,
         sum{},
         0,
         elem_num,
...
@@ -24,7 +24,8 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument&
     k[axis] = j;
     return k;
 };
-block_scan<block_size>(idx,
+block_scan<block_size>(
+    idx,
     sum{},
     0,
     n,
...
@@ -83,13 +83,12 @@ void gemm_impl(context& ctx,
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
     output_shape.visit_type([&](auto as) {
         auto alpha_r = as(alpha);
         auto beta_r  = as(beta);
         // use void pointer to select different data type if using fp32 mode
-        void* alpha_v{&alpha_r};
-        void* beta_v{&beta_r};
+        void* alpha_v = &alpha_r;
+        void* beta_v  = &beta_r;
         if(compute_fp32)
         {
...
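The void pointers here exist because, as the comment in the hunk says, the scalar type of alpha/beta depends on the compute mode: the if(compute_fp32) block that follows can point the same void* at scalars of a different type, and a gemm_ex-style BLAS entry point then reads whatever type matches the requested compute type. A hedged, self-contained illustration of the idiom; gemm_ex_like and half_t are stand-ins, not a real API:

// Hedged illustration of passing alpha/beta through a void* whose pointee
// type tracks the compute mode. Not real rocBLAS API.
#include <cstdio>

using half_t = unsigned short; // placeholder for a real half-precision type

void gemm_ex_like(const void* alpha, bool fp32_compute)
{
    // a gemm_ex-style routine trusts the caller that the pointee type
    // matches the requested compute type
    if(fp32_compute)
        std::printf("alpha = %f\n", *static_cast<const float*>(alpha));
    // otherwise it would read *static_cast<const half_t*>(alpha)
}

void call_site(bool compute_fp32)
{
    half_t alpha_h = 0x3c00;  // bit pattern of 1.0 in binary16
    float alpha_f  = 1.0f;
    void* alpha_v  = &alpha_h; // same pattern as the diff above
    if(compute_fp32)
        alpha_v = &alpha_f;    // repoint at a float scalar for fp32 compute
    gemm_ex_like(alpha_v, compute_fp32);
}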