Commit 87cd03e0 authored by Shucai Xiao

merge changes from develop branch

parents 46b25c33 36b01ba5
@@ -35,7 +35,7 @@ struct shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
-// clang-format on
+// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
@@ -188,8 +188,7 @@ struct shape
{
switch(t)
{
-case tuple_type:
-{
+case tuple_type: {
tv();
return;
}
......
@@ -131,16 +131,17 @@ inline std::string interpolate_string(const std::string& input,
std::string start = "${",
std::string end = "}")
{
-return interpolate_string(input,
-[&](auto start_it, auto last_it) {
-auto key = trim({start_it, last_it});
-auto it = vars.find(key);
-if(it == vars.end())
-throw std::runtime_error("Unknown key: " + key);
-return it->second;
-},
-std::move(start),
-std::move(end));
+return interpolate_string(
+input,
+[&](auto start_it, auto last_it) {
+auto key = trim({start_it, last_it});
+auto it = vars.find(key);
+if(it == vars.end())
+throw std::runtime_error("Unknown key: " + key);
+return it->second;
+},
+std::move(start),
+std::move(end));
}
template <class Iterator>
......
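For context, the reformatted overload above substitutes ${key} placeholders from the supplied map and throws on unknown keys. A minimal usage sketch (hypothetical values; assuming this overload is declared in migraphx/stringutils.hpp):

#include <migraphx/stringutils.hpp>
#include <iostream>
#include <string>
#include <unordered_map>

int main()
{
    std::unordered_map<std::string, std::string> vars{{"name", "migraphx"}};
    // Prints "hello migraphx"; an unmatched ${key} would throw std::runtime_error.
    std::cout << migraphx::interpolate_string("hello ${name}", vars) << "\n";
}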
@@ -315,8 +315,7 @@ struct value
{
switch(this->get_type())
{
-case null_type:
-{
+case null_type: {
std::nullptr_t null{};
if(this->key.empty())
v(null);
@@ -325,8 +324,7 @@ struct value
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE(vt, cpp_type) \
-case vt##_type: \
-{ \
+case vt##_type: { \
if(this->key.empty()) \
v(this->get_##vt()); \
else \
@@ -346,15 +344,13 @@ struct value
{
switch(this->get_type())
{
-case null_type:
-{
+case null_type: {
std::nullptr_t null{};
v(null);
return;
}
#define MIGRAPHX_VALUE_GENERATE_CASE_VALUE(vt, cpp_type) \
-case vt##_type: \
-{ \
+case vt##_type: { \
v(this->get_##vt()); \
return; \
}
......
@@ -14,44 +14,36 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
switch(o.type)
{
-case msgpack::type::NIL:
-{
+case msgpack::type::NIL: {
v = nullptr;
break;
}
-case msgpack::type::BOOLEAN:
-{
+case msgpack::type::BOOLEAN: {
v = o.as<bool>();
break;
}
-case msgpack::type::POSITIVE_INTEGER:
-{
+case msgpack::type::POSITIVE_INTEGER: {
v = o.as<std::uint64_t>();
break;
}
-case msgpack::type::NEGATIVE_INTEGER:
-{
+case msgpack::type::NEGATIVE_INTEGER: {
v = o.as<std::int64_t>();
break;
}
case msgpack::type::FLOAT32:
-case msgpack::type::FLOAT64:
-{
+case msgpack::type::FLOAT64: {
v = o.as<double>();
break;
}
-case msgpack::type::STR:
-{
+case msgpack::type::STR: {
v = o.as<std::string>();
break;
}
-case msgpack::type::BIN:
-{
+case msgpack::type::BIN: {
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
-case msgpack::type::ARRAY:
-{
+case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
@@ -60,8 +52,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
-case msgpack::type::MAP:
-{
+case msgpack::type::MAP: {
migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size,
@@ -71,7 +62,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
v = r;
break;
}
-case msgpack::type::EXT: { MIGRAPHX_THROW("msgpack EXT type not supported.");
+case msgpack::type::EXT: {
+MIGRAPHX_THROW("msgpack EXT type not supported.");
}
}
return o;
......
@@ -382,8 +382,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());
-case onnx::TensorProto::FLOAT16:
-{
+case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
@@ -453,7 +452,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
-default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+default: {
+MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
......
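For reference, the dtype codes above follow onnx::TensorProto's data-type enum; a hedged sketch of the mapping in use (inside the parser, where get_type is visible):

// 11 maps to shape::double_type per the table above; an unknown code
// now throws through the reformatted default case.
auto t = get_type(11);
// get_type(99) would throw: "Prototensor data type 99 not supported"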
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Floor", "floor"},
{"Gather", "gather"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
{"Log", "log"},
{"LRN", "lrn"},
......
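The one-line addition above routes ONNX IsNaN nodes through the generic parser to MIGraphX's elementwise operator named "isnan". A hedged sketch of what the table entry produces:

#include <migraphx/make_op.hpp>

// Sketch: after this commit, parsing an ONNX IsNaN node yields the
// operator below, like any other entry in the mapping table.
auto isnan_op = migraphx::make_op("isnan");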
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        // Dispatch on the optional ONNX "reduction" attribute; plain
        // assignment ("scatternd_none") is the default.
        if(contains(info.attributes, "reduction"))
        {
            if(info.attributes.at("reduction").s() == "add")
                return info.add_instruction(migraphx::make_op("scatternd_add"), args);
            if(info.attributes.at("reduction").s() == "mul")
                return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
        }
        return info.add_instruction(migraphx::make_op("scatternd_none"), args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -809,11 +809,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
-transform_if(m.begin(),
-m.end(),
-out,
-[&](auto&& pp) { return not contains(used, pp.first); },
-[](auto&& pp) { return &pp.second; });
+transform_if(
+m.begin(),
+m.end(),
+out,
+[&](auto&& pp) { return not contains(used, pp.first); },
+[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
......
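For readers unfamiliar with the helper: transform_if copies through out only the entries whose predicate holds, projecting each one; here it collects pointers to modules whose names are never referenced. A standard-library analogue of that call (a sketch, not MIGraphX's implementation):

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Filter-then-project; transform_if above fuses both steps into one pass.
std::vector<const int*> unused_values(std::vector<std::pair<std::string, int>>& m,
                                      const std::unordered_set<std::string>& used)
{
    std::vector<const int*> out;
    for(auto&& pp : m)
        if(used.count(pp.first) == 0)
            out.push_back(&pp.second);
    return out;
}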
@@ -303,86 +303,90 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name);
m.def("parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def(
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def(
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def(
"load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
......
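The bindings above mirror the C++ API one-to-one; max_loop_iterations is the newly exposed option. A hedged C++ equivalent of the Python parse_onnx call (the file name is a placeholder, and the public onnx header is assumed):

#include <migraphx/onnx.hpp>

migraphx::program load_model()
{
    migraphx::onnx_options options;
    options.default_dim_value   = 1;  // same defaults as the py::arg list above
    options.max_loop_iterations = 10;
    return migraphx::parse_onnx("model.onnx", options);
}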
@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
-while(reduce_dim(shapes, n) and n < shapes.size())
-{
-}
+while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1;
}
......
@@ -335,7 +335,6 @@ struct find_concat_op
}
auto y = p.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
......
@@ -316,7 +316,6 @@ struct find_nested_concat
else
args.push_back(i);
}
})(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args);
}
......
@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs)
{
return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0)
return true;
if(s.broadcasted())
......
@@ -133,6 +133,7 @@ add_library(migraphx_gpu
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
+compile_scatternd.cpp
concat.cpp
convert.cpp
convolution.cpp
......
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// HIP kernel source compiled at runtime by compile_hip_code_object below.
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
    make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
        // REDUCTION is injected with -D on the compile line (see options.params).
        scatternd(xs..., REDUCTION);
    });
}
}

} // namespace migraphx

int main() {}

)__migraphx__";

operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
    hip_compile_options options;
    auto out_s             = io_shapes.back();
    options.local          = 1024;
    // Launch one work-item per update element, rounded up to the block size.
    options.global         = compute_global(io_shapes.at(1).elements(), options.local);
    options.inputs         = io_shapes;
    options.output         = out_s;
    options.kernel_name    = "scatternd_kernel";
    options.virtual_inputs = io_shapes;
    // Expands to assign_none{}, assign_add{} or assign_mul{} inside the kernel.
    options.params += " -DREDUCTION=assign_" + reduction + "{}";
    return compile_hip_code_object(scatternd_kernel, options);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
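A sketch of how this entry point might be driven; the lowering pass that actually calls it is not part of this diff, and the shapes are placeholders:

// reduction is "none", "add" or "mul" and becomes
// -DREDUCTION=assign_<reduction>{} on the kernel's compile line.
migraphx::operation make_scatternd_add(migraphx::gpu::context& ctx,
                                       const std::vector<migraphx::shape>& io_shapes)
{
    return migraphx::gpu::compile_scatternd(ctx, io_shapes, "add");
}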
@@ -44,12 +44,13 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{
-block_scan<N>(idx,
-op,
-init,
-[&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
-input,
-output);
+block_scan<N>(
+idx,
+op,
+init,
+[&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
+input,
+output);
}
} // namespace device
......
@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
{
switch(n)
{
-case 1:
-{
+case 1: {
f(std::integral_constant<index_int, 1>{});
break;
}
-case 2:
-{
+case 2: {
f(std::integral_constant<index_int, 2>{});
break;
}
-case 3:
-{
+case 3: {
f(std::integral_constant<index_int, 3>{});
break;
}
-case 4:
-{
+case 4: {
f(std::integral_constant<index_int, 4>{});
break;
}
-case 5:
-{
+case 5: {
f(std::integral_constant<index_int, 5>{});
break;
}
......
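visit_tensor_size lifts a runtime rank (1 through 5) into a compile-time constant so templated device code can specialize on it. A hedged usage sketch (launch_kernel is a hypothetical helper):

visit_tensor_size(3, [&](auto ndim) {
    // ndim is std::integral_constant<index_int, 3> here, so its value
    // is usable as a template argument.
    constexpr auto rank = decltype(ndim)::value;
    launch_kernel<rank>(/*...*/);
});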
@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
// fill all output to 0 first
idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
-block_scan<block_size>(idx,
-sum{},
-0,
-elem_num,
-[&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
-[&](auto j, auto x) {
-auto out_loc = x - 1;
-if(float_equal(in_ptr[j], 0))
-return;
+block_scan<block_size>(
+idx,
+sum{},
+0,
+elem_num,
+[&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
+[&](auto j, auto x) {
+auto out_loc = x - 1;
+if(float_equal(in_ptr[j], 0))
+return;
-auto index = si.multi(j);
-for(size_t k = 0; k < index.size(); ++k)
-{
-ptr[k * elem_num + out_loc] = index[k];
-}
-});
+auto index = si.multi(j);
+for(size_t k = 0; k < index.size(); ++k)
+{
+ptr[k * elem_num + out_loc] = index[k];
+}
+});
});
});
......
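The scan above is stream compaction: each element maps to a 0/1 flag, and the inclusive prefix sum of the flags assigns every nonzero element its output column. A serial host-side analogue (a sketch; the device code additionally uses float_equal for the zero test):

#include <cstddef>
#include <vector>

// Each nonzero element lands in column (running_count - 1), matching
// out_loc = x - 1 in the scan's output callback above.
std::vector<std::size_t> nonzero_columns(const std::vector<float>& in)
{
    std::vector<std::size_t> cols;
    std::size_t count = 0;
    for(std::size_t j = 0; j < in.size(); ++j)
        if(in[j] != 0.0f)
            cols.push_back(count++);
    return cols;
}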
@@ -24,12 +24,13 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument&
k[axis] = j;
return k;
};
-block_scan<block_size>(idx,
-sum{},
-0,
-n,
-[&](auto j) { return input[compute_idx(j)]; },
-[&](auto j, auto x) { output[compute_idx(j)] = x; });
+block_scan<block_size>(
+idx,
+sum{},
+0,
+n,
+[&](auto j) { return input[compute_idx(j)]; },
+[&](auto j, auto x) { output[compute_idx(j)] = x; });
});
});
}
......
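block_scan here computes an inclusive prefix sum along the chosen axis of the input. A host-side analogue of the semantics using the standard library (a sketch only, not the device implementation):

#include <numeric>
#include <vector>

std::vector<int> inclusive_prefix_sum(const std::vector<int>& in)
{
    std::vector<int> out(in.size());
    // {1, 2, 3, 4} -> {1, 3, 6, 10}: out[j] = in[0] + ... + in[j]
    std::partial_sum(in.begin(), in.end(), out.begin());
    return out;
}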
@@ -83,13 +83,12 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
-void* alpha_v{&alpha_r};
-void* beta_v{&beta_r};
+void* alpha_v = &alpha_r;
+void* beta_v = &beta_r;
if(compute_fp32)
{
......