Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
594f2802
Unverified
Commit
594f2802
authored
Jan 05, 2022
by
turneram
Committed by
GitHub
Jan 05, 2022
Browse files
Fix time seed bug in random sequence ops (#1027)
Fix bug caused by casting time seed to float
parent
46b0c33b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
126 additions
and
18 deletions
+126
-18
src/onnx/parse_multinomial.cpp
src/onnx/parse_multinomial.cpp
+4
-6
src/onnx/parse_randomnormal_ops.cpp
src/onnx/parse_randomnormal_ops.cpp
+4
-6
src/onnx/parse_randomuniform_ops.cpp
src/onnx/parse_randomuniform_ops.cpp
+4
-6
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+45
-0
test/onnx/gen_onnx.pyc
test/onnx/gen_onnx.pyc
+0
-0
test/onnx/multinomial_generated_seed_test.onnx
test/onnx/multinomial_generated_seed_test.onnx
+15
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+24
-0
test/onnx/randomnormal_generated_seed_test.onnx
test/onnx/randomnormal_generated_seed_test.onnx
+15
-0
test/onnx/randomuniform_generated_seed_test.onnx
test/onnx/randomuniform_generated_seed_test.onnx
+15
-0
No files found.
src/onnx/parse_multinomial.cpp
View file @
594f2802
...
...
@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial>
if
(
contains
(
info
.
attributes
,
"sample_size"
))
sample_size
=
info
.
attributes
.
at
(
"sample_size"
).
i
();
float
seed
=
static_cast
<
float
>
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
seed
=
info
.
attributes
.
at
(
"seed"
).
f
();
// Subtract the per-batch maximum log-probability, making the per-batch max 0
auto
maxes
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
args
[
0
]);
...
...
@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial>
migraphx
::
make_op
(
"prefix_scan_sum"
,
{{
"axis"
,
1
},
{
"exclusive"
,
false
}}),
cdf
);
// Pre-compute random distribution
std
::
mt19937
gen
(
seed
);
std
::
mt19937
gen
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
size_t
batch_size
=
args
[
0
]
->
get_shape
().
lens
().
front
();
migraphx
::
shape
dist_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
sample_size
}};
...
...
src/onnx/parse_randomnormal_ops.cpp
View file @
594f2802
...
...
@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
if
(
contains
(
info
.
attributes
,
"scale"
))
scale
=
info
.
attributes
.
at
(
"scale"
).
f
();
float
seed
=
static_cast
<
float
>
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
seed
=
info
.
attributes
.
at
(
"seed"
).
f
();
shape
out_shape
;
if
(
contains
(
info
.
attributes
,
"shape"
))
{
...
...
@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
": cannot deduce shape without shape attribute or argument."
);
}
std
::
mt19937
gen
(
seed
);
std
::
mt19937
gen
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
normal_distribution
<>
d
(
mean
,
scale
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
src/onnx/parse_randomuniform_ops.cpp
View file @
594f2802
...
...
@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if
(
contains
(
info
.
attributes
,
"low"
))
low
=
info
.
attributes
.
at
(
"low"
).
f
();
float
seed
=
static_cast
<
float
>
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
seed
=
info
.
attributes
.
at
(
"seed"
).
f
();
shape
out_shape
;
if
(
contains
(
info
.
attributes
,
"shape"
))
{
...
...
@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
": cannot deduce shape without shape attribute or argument."
);
}
std
::
mt19937
gen
(
seed
);
std
::
mt19937
gen
(
std
::
chrono
::
high_resolution_clock
::
now
().
time_since_epoch
().
count
());
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
d
(
high
,
low
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
test/onnx/gen_onnx.py
View file @
594f2802
...
...
@@ -2725,6 +2725,21 @@ def multinomial_test():
return
([
node
],
[
input
],
[
output
])
@
onnx_test
def
multinomial_generated_seed_test
():
sample_size
=
10
input
=
helper
.
make_tensor_value_info
(
"input"
,
TensorProto
.
FLOAT
,
[
1
,
10
])
output
=
helper
.
make_tensor_value_info
(
"output"
,
TensorProto
.
INT32
,
[
1
,
10
])
node
=
onnx
.
helper
.
make_node
(
'Multinomial'
,
inputs
=
[
'input'
],
sample_size
=
sample_size
,
outputs
=
[
'output'
])
return
([
node
],
[
input
],
[
output
])
@
onnx_test
def
multinomial_dtype_error_test
():
sample_size
=
10
...
...
@@ -3176,6 +3191,21 @@ def randomnormal_dtype_error_test():
return
([
node
],
[],
[
output
])
@
onnx_test
def
randomnormal_generated_seed_test
():
sample_size
=
10
input
=
helper
.
make_tensor_value_info
(
"input"
,
TensorProto
.
FLOAT
,
[
1
,
10
])
output
=
helper
.
make_tensor_value_info
(
"output"
,
TensorProto
.
INT32
,
[
1
,
10
])
node
=
onnx
.
helper
.
make_node
(
'RandomNormal'
,
inputs
=
[
'input'
],
sample_size
=
sample_size
,
outputs
=
[
'output'
])
return
([
node
],
[
input
],
[
output
])
@
onnx_test
def
randomnormal_shape_error_test
():
dtype
=
1
...
...
@@ -3266,6 +3296,21 @@ def randomuniform_dtype_error_test():
return
([
node
],
[],
[
output
])
@
onnx_test
def
randomuniform_generated_seed_test
():
sample_size
=
10
input
=
helper
.
make_tensor_value_info
(
"input"
,
TensorProto
.
FLOAT
,
[
1
,
10
])
output
=
helper
.
make_tensor_value_info
(
"output"
,
TensorProto
.
INT32
,
[
1
,
10
])
node
=
onnx
.
helper
.
make_node
(
'RandomUniform'
,
inputs
=
[
'input'
],
sample_size
=
sample_size
,
outputs
=
[
'output'
])
return
([
node
],
[
input
],
[
output
])
@
onnx_test
def
randomuniform_shape_error_test
():
dtype
=
1
...
...
test/onnx/gen_onnx.pyc
View file @
594f2802
No preview for this file type
test/onnx/multinomial_generated_seed_test.onnx
0 → 100644
View file @
594f2802
multinomial_generated_seed_test:
0
inputoutput"Multinomial*
sample_size
multinomial_generated_seed_testZ
input
b
output
B
\ No newline at end of file
test/onnx/onnx_test.cpp
View file @
594f2802
...
...
@@ -2388,6 +2388,14 @@ TEST_CASE(multinomial_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("multinomial_dtype_error_test.onnx"); }));
}
TEST_CASE(multinomial_generated_seed_test)
{
auto p1 = optimize_onnx("multinomial_generated_seed_test.onnx");
auto p2 = optimize_onnx("multinomial_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(multinomial_int64_test)
{
migraphx::program p;
...
...
@@ -2891,6 +2899,14 @@ TEST_CASE(randomnormal_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_dtype_error_test.onnx"); }));
}
TEST_CASE(randomnormal_generated_seed_test)
{
auto p1 = optimize_onnx("randomnormal_generated_seed_test.onnx");
auto p2 = optimize_onnx("randomnormal_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(randomnormal_shape_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("randomnormal_shape_error_test.onnx"); }));
...
...
@@ -2953,6 +2969,14 @@ TEST_CASE(randomuniform_dtype_error_test)
EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_dtype_error_test.onnx"); }));
}
TEST_CASE(randomuniform_generated_seed_test)
{
auto p1 = optimize_onnx("randomuniform_generated_seed_test.onnx");
auto p2 = optimize_onnx("randomuniform_generated_seed_test.onnx");
EXPECT(p1 != p2);
}
TEST_CASE(randomuniform_shape_error_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("randomuniform_shape_error_test.onnx"); }));
...
...
test/onnx/randomnormal_generated_seed_test.onnx
0 → 100644
View file @
594f2802
randomnormal_generated_seed_test:
1
inputoutput"RandomNormal*
sample_size
randomnormal_generated_seed_testZ
input
b
output
B
\ No newline at end of file
test/onnx/randomuniform_generated_seed_test.onnx
0 → 100644
View file @
594f2802
!randomuniform_generated_seed_test:
2
inputoutput" RandomUniform*
sample_size
!randomuniform_generated_seed_testZ
input
b
output
B
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment