"results/vscode:/vscode.git/clone" did not exist on "af8bafcbfc293c5f58075d7c18da698c533655ad"
Commit d99729f8 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

work in progress. Doesn't run; needs dimensions_of op. to be merged

parent 0208239b
...@@ -49,6 +49,8 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -49,6 +49,8 @@ struct parse_multinomial : op_parser<parse_multinomial>
size_t sample_size = 1; size_t sample_size = 1;
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
else
MIGRAPHX_THROW("PARSE_MULTINOMIAL: sample_size not given");
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
...@@ -60,21 +62,57 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -60,21 +62,57 @@ struct parse_multinomial : op_parser<parse_multinomial>
cdf = info.add_instruction( cdf = info.add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution
std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0); // Make a shape that's the size of the sample set
size_t batch_size = args[0]->get_shape().max_lens().front(); shape s0 = args[0]->get_shape();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape dist_shape;
instruction_ref rand_dummy;
if(s0.dynamic())
{
dist_shape = {output_type, {s0.dyn_dims().front(), shape::dynamic_dimension({sample_size, sample_size})}};
auto temp = info.add_instruction(make_op("dimensions_of", {{"start", 0}, {"end", s0.ndim() - 1}}), args[0]);
auto asdf = temp->get_shape();
rand_dummy = info.add_instruction(migraphx::make_op("multibroadcast",
{{"out_dyn_dims", migraphx::to_value(dist_shape)}}), args[0], temp);
auto zap = rand_dummy->get_shape();
printf("hello %d\n", zap.ndim());
}
else
{
// use literal
size_t batch_size = s0.lens().front();
dist_shape = {output_type, {batch_size, sample_size}};
rand_dummy = info.add_literal(migraphx::literal{dist_shape, {batch_size, sample_size}});
// mul_random = info.add_instruction(migraphx::make_op("multibroadcast",
// {{"out_lens", migraphx::to_value(dist_shape)}}), args[0]);
// migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
}
// auto mul_random = info.add_instruction(migraphx::make_op("multibroadcast"
// ,{{"out_dyn_dims", migraphx::to_value(b)}}
// ), s0, dist_shape);
uint32_t seed(0);
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").i();
std::vector<float> random_dist(batch_size * sample_size); // how to populate data when dist_shape is dynamic? Answer: just send dist_shape`
std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); }); // std::vector<float> data(dist_shape.elements(), 0.f);
auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist}); // auto dummy = info.add_literal(migraphx::literal(dist_shape, data));
auto randoms = info.add_instruction(migraphx::make_op("rand_uniform", {{"seed", seed}}), rand_dummy);
return info.add_instruction( return info.add_instruction(
migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit); migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, randoms);
} }
}; };
......
...@@ -4162,15 +4162,17 @@ TEST_CASE(multinomial_dyn_test) ...@@ -4162,15 +4162,17 @@ TEST_CASE(multinomial_dyn_test)
cdf = mm->add_instruction( cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
std::mt19937 gen(seed); // std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0); // std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); // std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); // std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}}; migraphx::shape rs{migraphx::shape::float_type, {1, sample_size}};
auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); // auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples});
std::vector<float> data(rs.elements(), 0.3f);
auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, rs_lit); auto dummy = mm->add_literal(migraphx::literal(rs, data));
mm->add_return({ret}); auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform", {{"seed", seed}}), dummy);
auto ret = mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
// mm->add_return({ret});
// auto prog = optimize_onnx("multinomial_dyn_test.onnx"); // auto prog = optimize_onnx("multinomial_dyn_test.onnx");
migraphx::onnx_options options; migraphx::onnx_options options;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment