"vscode:/vscode.git/clone" did not exist on "19bc41fc738686d466d52879e138b308c7cd3b3a"
Commit da4fa01f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Work in progress for multinomial dynamic-shape support. Doesn't work yet.

parent a0fa3742
...@@ -208,6 +208,8 @@ instruction_ref insert_common_op(module& m, ...@@ -208,6 +208,8 @@ instruction_ref insert_common_op(module& m,
const operation& op, const operation& op,
std::vector<instruction_ref> inputs) std::vector<instruction_ref> inputs)
{ {
if(op.name() == "clip")
return inputs[0];
return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs))); return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
} }
......
...@@ -48,7 +48,7 @@ struct clip ...@@ -48,7 +48,7 @@ struct clip
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).same_type().same_dims(); check_shapes{inputs, *this, true}.has(3).same_type().same_dims();
return inputs.front(); return inputs.front();
} }
......
...@@ -47,22 +47,22 @@ struct multinomial ...@@ -47,22 +47,22 @@ struct multinomial
std::string name() const { return "multinomial"; } std::string name() const { return "multinomial"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).only_dims(2); check_shapes{inputs, *this, true}.has(2).only_dims(2);
size_t sample_size = inputs.back().lens().back(); size_t sample_size = inputs.back().max_lens().back();
if(not contains({shape::int32_type, shape::int64_type}, dtype)) if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW( MIGRAPHX_THROW(
"Multinomial: Invalid output type. Valid types are int32_type and int64_type."); "Multinomial: Invalid output type. Valid types are int32_type and int64_type.");
return {dtype, {inputs.front().lens().front(), sample_size}}; return {dtype, {inputs.front().max_lens().front(), sample_size}};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
size_t batch_size = output_shape.lens().front(); size_t batch_size = dyn_out.computed_shape.lens().front();
size_t class_size = args[0].get_shape().lens().back(); size_t class_size = args[0].get_shape().max_lens().back();
size_t sample_size = output_shape.lens().back(); size_t sample_size = dyn_out.computed_shape.lens().back();
visit_all(args[0], args[1])([&](auto cdf, auto dist) { visit_all(args[0], args[1])([&](auto cdf, auto dist) {
result.visit([&](auto output) { result.visit([&](auto output) {
......
...@@ -60,8 +60,12 @@ struct prefix_scan_op : op_name<Derived> ...@@ -60,8 +60,12 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s = inputs.front(); auto s = inputs.front();
if(s.dynamic())
{
return {s.type(), s.max_lens()};
}
if(s.broadcasted()) if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {s.type(), s.lens()};
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -53,10 +53,7 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -53,10 +53,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
auto mb_maxes = info.add_instruction( auto cdf = info.add_common_op("sub", args[0], maxes);
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
maxes);
auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);
// Take the element-wise exponent to get probabilities in the range (0, 1] // Take the element-wise exponent to get probabilities in the range (0, 1]
cdf = info.add_instruction(migraphx::make_op("exp"), cdf); cdf = info.add_instruction(migraphx::make_op("exp"), cdf);
// Compute the cumulative density function // Compute the cumulative density function
...@@ -69,7 +66,7 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -69,7 +66,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
gen.seed(info.attributes.at("seed").f()); gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front(); size_t batch_size = args[0]->get_shape().max_lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
std::vector<float> random_dist(batch_size * sample_size); std::vector<float> random_dist(batch_size * sample_size);
......
...@@ -4327,6 +4327,23 @@ def multinomial_test(): ...@@ -4327,6 +4327,23 @@ def multinomial_test():
return ([node], [input], [output]) return ([node], [input], [output])
@onnx_test()
def multinomial_dyn_test():
    # Multinomial with a dynamic batch dimension: both the float input
    # distribution and the int32 output use shape [None, 10], leaving the
    # batch size unknown until runtime.
    node = onnx.helper.make_node('Multinomial',
                                 inputs=['input'],
                                 sample_size=10,
                                 seed=0.0,
                                 outputs=['output'])
    input = helper.make_tensor_value_info("input", TensorProto.FLOAT,
                                          [None, 10])
    output = helper.make_tensor_value_info("output", TensorProto.INT32,
                                           [None, 10])
    return ([node], [input], [output])
@onnx_test() @onnx_test()
def multinomial_generated_seed_test(): def multinomial_generated_seed_test():
sample_size = 10 sample_size = 10
......
...@@ -4104,6 +4104,40 @@ TEST_CASE(multinomial_test) ...@@ -4104,6 +4104,40 @@ TEST_CASE(multinomial_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(multinomial_dyn_test)
{
    // Build the reference program that parsing multinomial_dyn_test.onnx is
    // expected to produce, with the dynamic batch dimension defaulted to {1, 10}.
    migraphx::program p;
    auto* mm           = p.get_main_module();
    size_t sample_size = 10;
    float seed         = 0.0f;

    auto input = mm->add_parameter(
        "input", migraphx::shape{migraphx::shape::float_type, {{1, 10}, {10, 10}}});
    // Stabilize the logits: subtract the per-batch maximum. add_common_op
    // inserts the broadcast needed to match the (dynamic) input shape.
    auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
    auto cdf   = add_common_op(*mm, migraphx::make_op("sub"), {input, maxes});
    // exp() followed by an inclusive prefix sum yields the unnormalized CDF.
    cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
    cdf = mm->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);

    // Random samples must match what the parser generates: an mt19937 engine
    // seeded from the float "seed" attribute (0.0f narrows to 0), uniform in [0, 1).
    std::mt19937 gen(seed);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    std::vector<float> rand_samples(sample_size);
    std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
    migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
    auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});

    auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.default_dyn_dim_value  = {1, 10};
    options.print_program_on_error = true;
    auto prog                      = migraphx::parse_onnx("multinomial_dyn_test.onnx", options);

    EXPECT(p == prog);
}
TEST_CASE(multinomial_dtype_error_test) TEST_CASE(multinomial_dtype_error_test)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); }));
......
...@@ -1748,6 +1748,13 @@ TEST_CASE(multinomial) ...@@ -1748,6 +1748,13 @@ TEST_CASE(multinomial)
throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s); throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s);
} }
TEST_CASE(multinomial_dyn)
{
    // Dynamic inputs: batch dimension in [2, 3], class/sample dimension in [5, 6].
    migraphx::shape s{migraphx::shape::int32_type, {{2, 3}, {5, 6}}};
    // multinomial::compute_shape builds its result from max_lens(), so the
    // output of dynamic inputs is the *static* shape {3, 6}, not the dynamic
    // input shape itself.
    migraphx::shape expected{migraphx::shape::int32_type, {3, 6}};
    expect_shape(
        expected, migraphx::make_op("multinomial", {{"dtype", migraphx::shape::int32_type}}), s, s);
}
TEST_CASE(nms_shape) TEST_CASE(nms_shape)
{ {
// use_dyn_output == false // use_dyn_output == false
......
...@@ -4915,6 +4915,56 @@ TEST_CASE(multinomial_test) ...@@ -4915,6 +4915,56 @@ TEST_CASE(multinomial_test)
EXPECT(migraphx::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify_range(norm, res_norm, 100000));
} }
TEST_CASE(multinomial_dyn_test)
{
    // End-to-end check on the ref target: draw 100000 multinomial samples
    // from a 5-class log-probability distribution and compare the empirical
    // class frequencies against the expected ones.
    migraphx::program p;
    auto* mm = p.get_main_module();
    size_t sample_size = 100000;
    float seed = 0.0f;
    // Uniform random samples in [0, 1); the float seed narrows to the
    // engine's unsigned seed type (0.0f -> 0).
    std::mt19937 gen(seed);
    std::uniform_real_distribution<> dis(0.0, 1.0);
    std::vector<float> rand_samples(sample_size);
    std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
    migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
    auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
    // NOTE(review): s is a *dynamic* shape {{1, 2}, {5, 6}} but only 5 data
    // values are supplied to the literal below — verify that literals accept
    // dynamic shapes; this looks like the reason the test doesn't work yet.
    migraphx::shape s{migraphx::shape::float_type, {{1, 2}, {5, 6}}};
    // Unnormalized class weights; the model takes log-probabilities as input.
    std::vector<int> dist{15, 25, 15, 25, 20};
    std::vector<float> data(5);
    std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
    auto input = mm->add_literal(migraphx::literal(s, data));
    // Stabilize the logits (subtract per-batch max), then exp() and an
    // inclusive prefix sum to form the unnormalized CDF multinomial expects.
    auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
    auto mb_maxes =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), maxes);
    auto cdf = mm->add_instruction(migraphx::make_op("sub"), input, mb_maxes);
    cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
    cdf = mm->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
    mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit);
    p.compile(migraphx::make_target("ref"));
    auto result = p.eval({}).back();
    std::vector<int32_t> result_vec(sample_size);
    result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
    // Histogram the sampled class indices.
    std::vector<int> res_dist(5, 0);
    for(const auto& r : result_vec)
        res_dist[r]++;
    // Normalize both distributions before comparing.
    auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
    auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
    std::vector<float> norm(5);
    std::vector<float> res_norm(5);
    std::transform(dist.begin(), dist.end(), norm.begin(), [&](auto n) {
        return static_cast<double>(n) / dist_sum;
    });
    std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
        return static_cast<double>(n) / res_dist_sum;
    });
    EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
TEST_CASE(neg_test) TEST_CASE(neg_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment