Commit 38920ca7 authored by Brian Pickrell
Browse files

work in progress for multinomial. Doesn't work

parent a0fa3742
......@@ -208,6 +208,8 @@ instruction_ref insert_common_op(module& m,
const operation& op,
std::vector<instruction_ref> inputs)
{
// Early exit: when the operation is named "clip", nothing is inserted into the module and
// the first input is handed back unchanged.
// NOTE(review): this skips the clip operation entirely; it looks like a work-in-progress
// shortcut rather than intended behavior -- confirm before merging.
if(op.name() == "clip")
return inputs[0];
// General case: insert `op` at `ins`, with its arguments produced by insert_common_args
// from the (moved) input list.
return m.insert_instruction(ins, op, insert_common_args(m, ins, std::move(inputs)));
}
......
......@@ -48,7 +48,7 @@ struct clip
// Validates the input shapes of the clip operator and reports its output shape, which is
// the shape of the first input.
shape compute_shape(std::vector<shape> inputs) const
{
// Validation chain applied to the inputs: has(3), same_type(), same_dims().
// NOTE(review): the next two statements differ only in the extra `true` argument passed to
// check_shapes. This listing appears to show the removed and the added line of a diff side
// by side; presumably only the second one is meant to remain -- confirm.
check_shapes{inputs, *this}.has(3).same_type().same_dims();
check_shapes{inputs, *this, true}.has(3).same_type().same_dims();
return inputs.front();
}
......
......@@ -47,22 +47,22 @@ struct multinomial
std::string name() const { return "multinomial"; }
// Validates the two inputs of the multinomial operator and reports its output shape.
// Throws (via MIGRAPHX_THROW) when `dtype` is neither int32_type nor int64_type.
shape compute_shape(std::vector<shape> inputs) const
{
// NOTE(review): this listing appears to interleave the removed and added lines of a diff
// (duplicate check_shapes calls, two declarations of sample_size, two return statements).
// Presumably the variants using `true`, max_lens() and normalize_standard() are the new
// ones -- confirm against the actual file.
check_shapes{inputs, *this}.has(2).only_dims(2);
// The sample count is read from the last dimension of the last input.
size_t sample_size = inputs.back().lens().back();
check_shapes{inputs, *this, true}.has(2).only_dims(2);
size_t sample_size = inputs.back().max_lens().back();
// Only 32-bit and 64-bit integer output types are accepted.
if(not contains({shape::int32_type, shape::int64_type}, dtype))
MIGRAPHX_THROW(
"Multinomial: Invalid output type. Valid types are int32_type and int64_type.");
// Output built from `dtype`, the first dimension of the first input, and sample_size.
return {dtype, {inputs.front().lens().front(), sample_size}};
// NOTE(review): unlike the return above, this one uses neither `dtype` nor `sample_size`;
// it is derived only from the first input's shape -- confirm this is intended.
return inputs.front().normalize_standard();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
size_t batch_size = output_shape.lens().front();
size_t class_size = args[0].get_shape().lens().back();
size_t sample_size = output_shape.lens().back();
argument result{dyn_out.computed_shape};
size_t batch_size = dyn_out.computed_shape.lens().front();
size_t class_size = args[0].get_shape().max_lens().back();
size_t sample_size = dyn_out.computed_shape.lens().back();
visit_all(args[0], args[1])([&](auto cdf, auto dist) {
result.visit([&](auto output) {
......
......@@ -60,8 +60,12 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.front();
if(s.dynamic())
{
return {s.type(), s.max_lens()};
}
if(s.broadcasted())
{
return {s.type(), s.lens()};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -53,10 +53,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
// Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
auto mb_maxes = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
maxes);
auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);
auto cdf = info.add_common_op("sub", args[0], maxes);
// Take the element-wise exponent to get probabilities in the range (0, 1]
cdf = info.add_instruction(migraphx::make_op("exp"), cdf);
// Compute the cumulative density function
......@@ -69,7 +66,7 @@ struct parse_multinomial : op_parser<parse_multinomial>
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front();
size_t batch_size = args[0]->get_shape().max_lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
std::vector<float> random_dist(batch_size * sample_size);
......
......@@ -4327,6 +4327,23 @@ def multinomial_test():
return ([node], [input], [output])
@onnx_test()
def multinomial_dyn_test():
    """Describe a single-node Multinomial graph with an unspecified first dimension.

    Both the "input" (FLOAT) and "output" (INT32) tensor infos are declared with
    dimensions ``[None, 10]``. The node is created with ``sample_size=10`` and
    ``seed=0.0``.

    Returns a tuple of three lists: nodes, graph inputs, graph outputs.
    """
    graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT,
                                                [None, 10])
    graph_output = helper.make_tensor_value_info("output", TensorProto.INT32,
                                                 [None, 10])

    multinomial_node = onnx.helper.make_node('Multinomial',
                                             inputs=['input'],
                                             sample_size=10,
                                             seed=0.0,
                                             outputs=['output'])

    return ([multinomial_node], [graph_input], [graph_output])
@onnx_test()
def multinomial_generated_seed_test():
sample_size = 10
......
......@@ -4104,6 +4104,40 @@ TEST_CASE(multinomial_test)
EXPECT(p == prog);
}
TEST_CASE(multinomial_dyn_test)
{
    // Hand-built reference program that is compared against the result of parsing
    // "multinomial_dyn_test.onnx".
    migraphx::program p;
    auto* main_mod = p.get_main_module();

    size_t num_samples = 10;
    float rng_seed     = 0.0f;

    // Input parameter declared with dynamic dimensions {1, 10} and {10, 10}.
    migraphx::shape input_shape{migraphx::shape::float_type, {{1, 10}, {10, 10}}};
    auto data    = main_mod->add_parameter("input", input_shape);
    auto row_max = main_mod->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), data);

    // The subtraction goes through add_common_op rather than an explicit
    // multibroadcast followed by "sub".
    auto cdf = add_common_op(*main_mod, migraphx::make_op("sub"), {data, row_max});
    cdf      = main_mod->add_instruction(migraphx::make_op("exp"), cdf);
    cdf      = main_mod->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);

    // Uniform samples in [0, 1) drawn from a seeded Mersenne Twister and stored as a literal.
    std::mt19937 engine(rng_seed);
    std::uniform_real_distribution<> unit_dist(0.0, 1.0);
    std::vector<float> samples(num_samples);
    std::generate_n(samples.begin(), samples.size(), [&] { return unit_dist(engine); });
    migraphx::shape samples_shape{migraphx::shape::float_type, {1, num_samples}};
    auto samples_lit = main_mod->add_literal(migraphx::literal{samples_shape, samples});

    auto picked = main_mod->add_instruction(migraphx::make_op("multinomial"), cdf, samples_lit);
    main_mod->add_return({picked});

    // Parse with explicit options: a default dynamic dimension of {1, 10} and program
    // printing on error enabled.
    migraphx::onnx_options parse_opts;
    parse_opts.default_dyn_dim_value  = {1, 10};
    parse_opts.print_program_on_error = true;
    auto prog = migraphx::parse_onnx("multinomial_dyn_test.onnx", parse_opts);

    EXPECT(p == prog);
}
TEST_CASE(multinomial_dtype_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); }));
......
......@@ -1748,6 +1748,13 @@ TEST_CASE(multinomial)
throws_shape(migraphx::make_op("multinomial", {{"dtype", dtype}}), s, s);
}
TEST_CASE(multinomial_dyn)
{
    // One dynamic int32 shape serves as both inputs and as the expected output shape.
    const migraphx::shape dyn_shape{migraphx::shape::int32_type, {{2, 3}, {5, 6}}};
    const auto multinomial_op =
        migraphx::make_op("multinomial", {{"dtype", migraphx::shape::int32_type}});
    expect_shape(dyn_shape, multinomial_op, dyn_shape, dyn_shape);
}
TEST_CASE(nms_shape)
{
// use_dyn_output == false
......
......@@ -4915,6 +4915,56 @@ TEST_CASE(multinomial_test)
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
TEST_CASE(multinomial_dyn_test)
{
    // Builds a multinomial sampling pipeline on the "ref" target and checks that the
    // normalized histogram of the drawn indices matches the normalized input weights.
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    size_t num_samples = 100000;
    float rng_seed     = 0.0f;

    // Input weights and their natural logarithms (the values fed to the pipeline).
    std::vector<int> dist{15, 25, 15, 25, 20};
    std::vector<float> log_weights(5);
    std::transform(
        dist.begin(), dist.end(), log_weights.begin(), [](auto d) { return std::log(d); });

    // Uniform samples from a seeded Mersenne Twister, stored as a {1, num_samples} literal.
    std::mt19937 engine(rng_seed);
    std::uniform_real_distribution<> unit_dist(0.0, 1.0);
    std::vector<float> samples(num_samples);
    std::generate_n(samples.begin(), samples.size(), [&] { return unit_dist(engine); });
    migraphx::shape samples_shape{migraphx::shape::float_type, {1, num_samples}};
    auto samples_lit = main_mod->add_literal(migraphx::literal{samples_shape, samples});

    // The input literal is declared with dynamic dimensions {1, 2} and {5, 6}.
    migraphx::shape input_shape{migraphx::shape::float_type, {{1, 2}, {5, 6}}};
    auto data = main_mod->add_literal(migraphx::literal(input_shape, log_weights));

    // data - max(data) -> exp -> inclusive prefix sum along axis 1 -> multinomial
    auto row_max =
        main_mod->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), data);
    auto row_max_bcast = main_mod->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 5}}}), row_max);
    auto cdf = main_mod->add_instruction(migraphx::make_op("sub"), data, row_max_bcast);
    cdf      = main_mod->add_instruction(migraphx::make_op("exp"), cdf);
    cdf      = main_mod->add_instruction(
        migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
    main_mod->add_instruction(migraphx::make_op("multinomial"), cdf, samples_lit);

    prog.compile(migraphx::make_target("ref"));
    auto result = prog.eval({}).back();

    std::vector<int32_t> drawn(num_samples);
    result.visit([&](auto output) { drawn.assign(output.begin(), output.end()); });

    // Count how often each of the five indices was drawn.
    std::vector<int> res_dist(5, 0);
    for(auto idx : drawn)
        ++res_dist[idx];

    // Turns a vector of counts into per-entry fractions of the total.
    auto to_fractions = [](const std::vector<int>& counts) {
        auto total = std::accumulate(counts.begin(), counts.end(), 0);
        std::vector<float> fractions(counts.size());
        std::transform(counts.begin(), counts.end(), fractions.begin(), [&](auto n) {
            return static_cast<double>(n) / total;
        });
        return fractions;
    };
    auto norm     = to_fractions(dist);
    auto res_norm = to_fractions(res_dist);

    EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
TEST_CASE(neg_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment