Commit 634f5844 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

clean up merge bugs

parent 6fc3f195
...@@ -46,7 +46,7 @@ namespace op { ...@@ -46,7 +46,7 @@ namespace op {
struct rand_uniform struct rand_uniform
{ {
uint32_t sample_size = {23}; uint32_t sample_size = {20};
uint32_t seed = {0}; uint32_t seed = {0};
shape::type_t dtype = shape::type_t::float_type; shape::type_t dtype = shape::type_t::float_type;
...@@ -61,15 +61,12 @@ struct rand_uniform ...@@ -61,15 +61,12 @@ struct rand_uniform
std::string name() const { return "rand_uniform"; } std::string name() const { return "rand_uniform"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() > 0)
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s = inputs.front(); auto s = inputs.front();
if(s.dynamic()) if(s.dynamic())
{ {
// return s; return s;
return {dtype, {s.dyn_dims()[0], {sample_size, sample_size}}};
} }
else if(s.broadcasted()) else if(s.broadcasted())
{ {
...@@ -77,15 +74,9 @@ struct rand_uniform ...@@ -77,15 +74,9 @@ struct rand_uniform
} }
else else
{ {
// For static input, return the input shape. Assume the batch_size and sample_size
// have already been factored in. This saves us from reallocating a shape at
// runtime when the input is a literal.
return s.with_lens(s.lens()); return s.with_lens(s.lens());
} }
} }
// No input instruction is required. 1-dimensional static output.
return shape{dtype, {sample_size}};
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
......
...@@ -5366,7 +5366,7 @@ TEST_CASE(multinomial_dyn_test) ...@@ -5366,7 +5366,7 @@ TEST_CASE(multinomial_dyn_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify::verify_range(norm, res_norm, 1000000));
} }
TEST_CASE(neg_test) TEST_CASE(neg_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment