Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
50da6a8c
Commit
50da6a8c
authored
Aug 23, 2023
by
Brian Pickrell
Browse files
misc code cleanup. Seed can be any type.
parent
aa517bd9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
19 deletions
+14
-19
src/include/migraphx/op/random_uniform.hpp
src/include/migraphx/op/random_uniform.hpp
+11
-16
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+3
-3
No files found.
src/include/migraphx/op/random_uniform.hpp
View file @
50da6a8c
...
...
@@ -44,35 +44,30 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/reflect.hpp>
#include <random>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* random_uniform populates the passed shape with random numbers, in a uniform
* distribution. Range for floating-point data types is (0, 1);
* for integer types it is [0, <max value for the type>]
*
* Input 1: seed
* Input 2: output shape
*/
struct
random_uniform
{
// The random_uniform operation does not contain a random number generator seed
// as a member, and expects it to be passed as a runtime input.
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
();
}
// The random_uniform operation needs the random number generator seed
// to be passed as a runtime input.
/**
* Input 1: seed
* Input 2: output shape
*/
std
::
string
name
()
const
{
return
"random_uniform"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
if
(
inputs
.
front
().
type
()
!=
shape
::
type_t
::
uint64_type
)
MIGRAPHX_THROW
(
"RANDOM_UNIFORM: Input 1 (seed) must have type long unsigned int"
);
auto
s
=
inputs
.
at
(
1
);
if
(
s
.
dynamic
())
{
...
...
@@ -98,7 +93,7 @@ struct random_uniform
{
// default range for all integer types is (0,
// std::uniform_int_distribution<type>::max()).
// To
clamp to a
different range
, apply min or max ops. to the output of this.
// To
do: enable
different range
s
std
::
uniform_int_distribution
<
type
>
dis
;
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
}
...
...
test/ref_ops_test.cpp
View file @
50da6a8c
...
...
@@ -6498,7 +6498,7 @@ TEST_CASE(random_uniform_int_test)
// random uniform distribution with an integer type input shape
migraphx::program p;
auto* mm = p.get_main_module();
uint64_
t seed(0);
floa
t seed(0
.1
);
size_t sample_size(200);
// Shape of the random data
...
...
@@ -6509,8 +6509,8 @@ TEST_CASE(random_uniform_int_test)
auto input = mm->add_literal(migraphx::literal(rs, data));
// Runtime randomization seed
migraphx::shape seed_shape{migraphx::shape::
uint64
_type, {1}};
std::vector<
uint64_
t> seed_data{seed};
migraphx::shape seed_shape{migraphx::shape::
float
_type, {1}};
std::vector<
floa
t> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("random_uniform"), seed_input, input);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment