Unverified Commit 594f2802 authored by turneram's avatar turneram Committed by GitHub
Browse files

Fix time seed bug in random sequence ops (#1027)

Fix bug caused by casting time seed to float
parent 46b0c33b
...@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial>
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
...@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial>
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution // Pre-compute random distribution
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front(); size_t batch_size = args[0]->get_shape().lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
......
...@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> ...@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
if(contains(info.attributes, "scale")) if(contains(info.attributes, "scale"))
scale = info.attributes.at("scale").f(); scale = info.attributes.at("scale").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape; shape out_shape;
if(contains(info.attributes, "shape")) if(contains(info.attributes, "shape"))
{ {
...@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> ...@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
": cannot deduce shape without shape attribute or argument."); ": cannot deduce shape without shape attribute or argument.");
} }
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::normal_distribution<> d(mean, scale); std::normal_distribution<> d(mean, scale);
std::vector<double> rand_vals(out_shape.elements()); std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
...@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> ...@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if(contains(info.attributes, "low")) if(contains(info.attributes, "low"))
low = info.attributes.at("low").f(); low = info.attributes.at("low").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape; shape out_shape;
if(contains(info.attributes, "shape")) if(contains(info.attributes, "shape"))
{ {
...@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> ...@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
": cannot deduce shape without shape attribute or argument."); ": cannot deduce shape without shape attribute or argument.");
} }
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> d(high, low); std::uniform_real_distribution<> d(high, low);
std::vector<double> rand_vals(out_shape.elements()); std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
...@@ -2725,6 +2725,21 @@ def multinomial_test(): ...@@ -2725,6 +2725,21 @@ def multinomial_test():
return ([node], [input], [output]) return ([node], [input], [output])
@onnx_test
def multinomial_generated_seed_test():
sample_size = 10
input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10])
output = helper.make_tensor_value_info("output", TensorProto.INT32,
[1, 10])
node = onnx.helper.make_node('Multinomial',
inputs=['input'],
sample_size=sample_size,
outputs=['output'])
return ([node], [input], [output])
@onnx_test @onnx_test
def multinomial_dtype_error_test(): def multinomial_dtype_error_test():
sample_size = 10 sample_size = 10
...@@ -3176,6 +3191,21 @@ def randomnormal_dtype_error_test(): ...@@ -3176,6 +3191,21 @@ def randomnormal_dtype_error_test():
return ([node], [], [output]) return ([node], [], [output])
@onnx_test
def randomnormal_generated_seed_test():
sample_size = 10
input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10])
output = helper.make_tensor_value_info("output", TensorProto.INT32,
[1, 10])
node = onnx.helper.make_node('RandomNormal',
inputs=['input'],
sample_size=sample_size,
outputs=['output'])
return ([node], [input], [output])
@onnx_test @onnx_test
def randomnormal_shape_error_test(): def randomnormal_shape_error_test():
dtype = 1 dtype = 1
...@@ -3266,6 +3296,21 @@ def randomuniform_dtype_error_test(): ...@@ -3266,6 +3296,21 @@ def randomuniform_dtype_error_test():
return ([node], [], [output]) return ([node], [], [output])
@onnx_test
def randomuniform_generated_seed_test():
sample_size = 10
input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10])
output = helper.make_tensor_value_info("output", TensorProto.INT32,
[1, 10])
node = onnx.helper.make_node('RandomUniform',
inputs=['input'],
sample_size=sample_size,
outputs=['output'])
return ([node], [input], [output])
@onnx_test @onnx_test
def randomuniform_shape_error_test(): def randomuniform_shape_error_test():
dtype = 1 dtype = 1
......
No preview for this file type
multinomial_generated_seed_test:
0
inputoutput" Multinomial*
sample_size
multinomial_generated_seed_testZ
input



b
output



B
\ No newline at end of file
...@@ -2388,6 +2388,14 @@ TEST_CASE(multinomial_dtype_error_test) ...@@ -2388,6 +2388,14 @@ TEST_CASE(multinomial_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); }));
} }
TEST_CASE(multinomial_generated_seed_test)
{
auto p1 = optimize_onnx("multinomial_generated_seed_test.onnx");
auto p2 = optimize_onnx("multinomial_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(multinomial_int64_test) TEST_CASE(multinomial_int64_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2891,6 +2899,14 @@ TEST_CASE(randomnormal_dtype_error_test) ...@@ -2891,6 +2899,14 @@ TEST_CASE(randomnormal_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_dtype_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_dtype_error_test.onnx"); }));
} }
TEST_CASE(randomnormal_generated_seed_test)
{
auto p1 = optimize_onnx("randomnormal_generated_seed_test.onnx");
auto p2 = optimize_onnx("randomnormal_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(randomnormal_shape_error_test) TEST_CASE(randomnormal_shape_error_test)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_shape_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_shape_error_test.onnx"); }));
...@@ -2953,6 +2969,14 @@ TEST_CASE(randomuniform_dtype_error_test) ...@@ -2953,6 +2969,14 @@ TEST_CASE(randomuniform_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_dtype_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_dtype_error_test.onnx"); }));
} }
TEST_CASE(randomuniform_generated_seed_test)
{
auto p1 = optimize_onnx("randomuniform_generated_seed_test.onnx");
auto p2 = optimize_onnx("randomuniform_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(randomuniform_shape_error_test) TEST_CASE(randomuniform_shape_error_test)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_shape_error_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_shape_error_test.onnx"); }));
......
 randomnormal_generated_seed_test:
1
inputoutput" RandomNormal*
sample_size
 randomnormal_generated_seed_testZ
input



b
output



B
\ No newline at end of file
!randomuniform_generated_seed_test:
2
inputoutput" RandomUniform*
sample_size
!randomuniform_generated_seed_testZ
input



b
output



B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment