"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "49f7b40139a052003863c0326c4a1509084c1b56"
Commit 06cc4f8f authored by wsttiger's avatar wsttiger
Browse files

did merge from master

parents 7009dc1d 8b6a35bb
...@@ -437,17 +437,23 @@ struct flatten ...@@ -437,17 +437,23 @@ struct flatten
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto&& lens = inputs.front().lens();
if(axis == 0) if(axis == 0)
{ {
return {inputs.at(0).type(), {1, inputs.at(0).elements()}}; return {inputs.at(0).type(), {1, inputs.at(0).elements()}};
} }
if(axis == 1) else if(axis < lens.size())
{ {
return {inputs.at(0).type(), {inputs.at(0).elements(), 1}}; auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
} }
else else
{ {
MIGRAPH_THROW("axis can only be either 0 or 1"); MIGRAPH_THROW("axis for flatten must be less than tensor rank");
} }
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
...@@ -455,7 +461,6 @@ struct flatten ...@@ -455,7 +461,6 @@ struct flatten
return {output_shape, std::move(args.front().data)}; return {output_shape, std::move(args.front().data)};
} }
}; };
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
namespace test { namespace migraph {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -65,12 +65,6 @@ struct not_finite_fn ...@@ -65,12 +65,6 @@ struct not_finite_fn
}; };
static constexpr not_finite_fn not_finite{}; static constexpr not_finite_fn not_finite{};
template <class T, class U>
T as(T, U x)
{
return x;
}
struct compare_mag_fn struct compare_mag_fn
{ {
template <class T, class U> template <class T, class U>
...@@ -172,5 +166,5 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80) ...@@ -172,5 +166,5 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80)
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
return error <= threshold; return error <= threshold;
} }
} // namespace test } // namespace migraph
#endif #endif
...@@ -60,8 +60,8 @@ struct onnx_parser ...@@ -60,8 +60,8 @@ struct onnx_parser
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_max_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_average_pooling); add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape); add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten); add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
...@@ -129,9 +129,9 @@ struct onnx_parser ...@@ -129,9 +129,9 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_max_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args) parse_pooling(std::string name, attribute_map attributes, std::vector<instruction_ref> args)
{ {
pooling op{"max"}; pooling op{name == "MaxPool" ? "max" : "average"};
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
copy(attributes["pads"].ints(), op.padding.begin()); copy(attributes["pads"].ints(), op.padding.begin());
...@@ -187,10 +187,10 @@ struct onnx_parser ...@@ -187,10 +187,10 @@ struct onnx_parser
parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args) parse_flatten(std::string, attribute_map attributes, std::vector<instruction_ref> args)
{ {
uint64_t axis = 0; uint64_t axis = 0;
// if(contains(attributes, "axis")) if(contains(attributes, "axis"))
// { {
// axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
// } }
return prog.add_instruction(flatten{axis}, args[0]); return prog.add_instruction(flatten{axis}, args[0]);
} }
......
...@@ -78,7 +78,7 @@ int main(int argc, char const* argv[]) ...@@ -78,7 +78,7 @@ int main(int argc, char const* argv[])
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits); std::vector<float> probs = softmax(logits);
for(auto x : logits) for(auto x : probs)
std::cout << x << " "; std::cout << x << " ";
std::cout << std::endl; std::cout << std::endl;
std::cout << std::endl; std::cout << std::endl;
......
...@@ -5,16 +5,18 @@ ...@@ -5,16 +5,18 @@
#include <migraph/gpu/target.hpp> #include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraph/generate.hpp>
#include <miopen/miopen.h> #include <migraph/verify.hpp>
#include <migraph/gpu/miopen.hpp>
migraph::argument run_cpu(std::string file) migraph::argument run_cpu(std::string file)
{ {
auto p = migraph::parse_onnx(file); auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto s = p.get_parameter_shape("Input3"); migraph::program::parameter_map m;
auto input3 = migraph::generate_argument(s); for(auto&& x : p.get_parameter_shapes())
auto out = p.eval({{"Input3", input3}}); {
m[x.first] = migraph::generate_argument(x.second);
}
auto out = p.eval(m);
std::cout << p << std::endl; std::cout << p << std::endl;
return out; return out;
} }
...@@ -22,14 +24,14 @@ migraph::argument run_cpu(std::string file) ...@@ -22,14 +24,14 @@ migraph::argument run_cpu(std::string file)
migraph::argument run_gpu(std::string file) migraph::argument run_gpu(std::string file)
{ {
auto p = migraph::parse_onnx(file); auto p = migraph::parse_onnx(file);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::gpu::target{});
auto s = p.get_parameter_shape("Input3");
auto input3 = migraph::gpu::to_gpu(migraph::generate_argument(s));
auto output = migraph::gpu::to_gpu(migraph::generate_argument(p.get_parameter_shape("output")));
auto handle = migraph::gpu::make_obj<migraph::gpu::miopen_handle>(&miopenCreate);
auto out = p.eval({{"Input3", input3}, {"output", output}}); migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto out = migraph::gpu::from_gpu(p.eval(m));
std::cout << p << std::endl; std::cout << p << std::endl;
return migraph::gpu::from_gpu(out); return migraph::gpu::from_gpu(out);
} }
...@@ -41,15 +43,18 @@ int main(int argc, char const* argv[]) ...@@ -41,15 +43,18 @@ int main(int argc, char const* argv[])
std::string file = argv[1]; std::string file = argv[1];
auto x = run_cpu(file); auto x = run_cpu(file);
auto y = run_gpu(file); auto y = run_gpu(file);
if(x == y) visit_all(x, y)([](auto cpu, auto gpu) {
{ if(migraph::verify_range(cpu, gpu))
std::cout << "Passed" << std::endl; {
} std::cout << "Passed" << std::endl;
else }
{ else
std::cout << "Not equal" << std::endl; {
std::cout << x << std::endl; std::cout << "Not equal" << std::endl;
std::cout << y << std::endl; std::cout << cpu << std::endl;
} std::cout << gpu << std::endl;
}
});
} }
} }
#include <migraph/cpu/cpu_target.hpp> #include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_lowering.hpp> #include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/auto_contiguous.hpp>
namespace migraph { namespace migraph {
namespace cpu { namespace cpu {
std::string cpu_target::name() const { return "cpu"; } std::string cpu_target::name() const { return "cpu"; }
std::vector<pass> cpu_target::get_passes(context&) const { return {cpu_lowering{}}; } std::vector<pass> cpu_target::get_passes(context&) const
{
return {auto_contiguous{}, cpu_lowering{}};
}
} // namespace cpu } // namespace cpu
......
...@@ -138,7 +138,7 @@ struct miopen_pooling ...@@ -138,7 +138,7 @@ struct miopen_pooling
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(1)}); return op.compute_shape({inputs.at(0)});
} }
argument compute(context& ctx, shape output_shape, std::vector<argument> args) const argument compute(context& ctx, shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/cpu/cpu_target.hpp> #include <migraph/cpu/cpu_target.hpp>
#include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
#include "verify.hpp"
void batch_norm_inference_test() void batch_norm_inference_test()
{ {
...@@ -43,7 +43,7 @@ void batch_norm_inference_test() ...@@ -43,7 +43,7 @@ void batch_norm_inference_test()
std::fill(gold.begin(), gold.end(), output_val); std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(result_vector, gold)); EXPECT(migraph::verify_range(result_vector, gold));
} }
void exp_test() void exp_test()
...@@ -57,7 +57,7 @@ void exp_test() ...@@ -57,7 +57,7 @@ void exp_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f}; std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void sin_test() void sin_test()
...@@ -71,7 +71,7 @@ void sin_test() ...@@ -71,7 +71,7 @@ void sin_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f}; std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void cos_test() void cos_test()
...@@ -85,7 +85,7 @@ void cos_test() ...@@ -85,7 +85,7 @@ void cos_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f}; std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void tan_test() void tan_test()
...@@ -99,7 +99,7 @@ void tan_test() ...@@ -99,7 +99,7 @@ void tan_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f}; std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void add_test() void add_test()
...@@ -114,7 +114,7 @@ void add_test() ...@@ -114,7 +114,7 @@ void add_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void broadcast_test() void broadcast_test()
...@@ -154,7 +154,7 @@ void add_broadcast_test() ...@@ -154,7 +154,7 @@ void add_broadcast_test()
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void sub_test() void sub_test()
...@@ -169,7 +169,7 @@ void sub_test() ...@@ -169,7 +169,7 @@ void sub_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void mul_test() void mul_test()
...@@ -184,7 +184,7 @@ void mul_test() ...@@ -184,7 +184,7 @@ void mul_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1, 0, 3}; std::vector<float> gold = {-1, 0, 3};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void div_test() void div_test()
...@@ -199,7 +199,7 @@ void div_test() ...@@ -199,7 +199,7 @@ void div_test()
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.f, 0.25f, 0.25f}; std::vector<float> gold = {-1.f, 0.25f, 0.25f};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void reshape_test() void reshape_test()
...@@ -216,7 +216,7 @@ void reshape_test() ...@@ -216,7 +216,7 @@ void reshape_test()
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, data)); EXPECT(migraph::verify_range(results_vector, data));
} }
{ {
migraph::program p; migraph::program p;
...@@ -227,7 +227,7 @@ void reshape_test() ...@@ -227,7 +227,7 @@ void reshape_test()
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, data)); EXPECT(migraph::verify_range(results_vector, data));
} }
{ {
migraph::program p; migraph::program p;
...@@ -238,7 +238,7 @@ void reshape_test() ...@@ -238,7 +238,7 @@ void reshape_test()
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, data)); EXPECT(migraph::verify_range(results_vector, data));
} }
} }
...@@ -406,7 +406,7 @@ void softmax_test() ...@@ -406,7 +406,7 @@ void softmax_test()
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_test() void conv2d_test()
...@@ -469,7 +469,7 @@ void conv2d_test() ...@@ -469,7 +469,7 @@ void conv2d_test()
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_padding_test() void conv2d_padding_test()
...@@ -525,7 +525,7 @@ void conv2d_padding_test() ...@@ -525,7 +525,7 @@ void conv2d_padding_test()
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_padding_stride_test() void conv2d_padding_stride_test()
...@@ -586,7 +586,7 @@ void conv2d_padding_stride_test() ...@@ -586,7 +586,7 @@ void conv2d_padding_stride_test()
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(test::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void transpose_test() void transpose_test()
...@@ -622,7 +622,7 @@ void transpose_test() ...@@ -622,7 +622,7 @@ void transpose_test()
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
} }
...@@ -643,7 +643,7 @@ void contiguous_test() ...@@ -643,7 +643,7 @@ void contiguous_test()
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 0}; std::vector<float> gold = {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 0};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
int main() int main()
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraph/gpu/hip.hpp> #include <migraph/gpu/hip.hpp>
#include <migraph/manage_ptr.hpp> #include <migraph/manage_ptr.hpp>
#include <migraph/type_name.hpp> #include <migraph/type_name.hpp>
#include <migraph/verify.hpp>
#include <miopen/miopen.h> #include <miopen/miopen.h>
...@@ -15,7 +16,6 @@ ...@@ -15,7 +16,6 @@
#include <thread> #include <thread>
#include "test.hpp" #include "test.hpp"
#include "verify.hpp"
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic push #pragma clang diagnostic push
...@@ -102,7 +102,7 @@ void verify_program() ...@@ -102,7 +102,7 @@ void verify_program()
auto cpu_arg_f = detach_async([] { return run_cpu<V>(); }); auto cpu_arg_f = detach_async([] { return run_cpu<V>(); });
auto gpu_arg = run_gpu<V>(); auto gpu_arg = run_gpu<V>();
visit_all(cpu_arg_f.get(), gpu_arg)([](auto cpu, auto gpu) { visit_all(cpu_arg_f.get(), gpu_arg)([](auto cpu, auto gpu) {
if(not test::verify_range(cpu, gpu)) if(not migraph::verify_range(cpu, gpu))
{ {
std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl; std::cout << "FAILED: " << migraph::get_type_name<V>() << std::endl;
} }
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/onnx.hpp> #include <migraph/onnx.hpp>
#include "test.hpp" #include "test.hpp"
#include "verify.hpp"
void pytorch_conv_bias_test() void pytorch_conv_bias_test()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment