Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2ef739bd
Commit
2ef739bd
authored
Aug 14, 2023
by
Brian Pickrell
Browse files
breaks ref_ops_test
parent
1f9e7402
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
36 deletions
+25
-36
src/include/migraphx/op/rand_uniform.hpp
src/include/migraphx/op/rand_uniform.hpp
+22
-20
test/op_shape_test.cpp
test/op_shape_test.cpp
+2
-10
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+1
-6
No files found.
src/include/migraphx/op/rand_uniform.hpp
View file @
2ef739bd
...
@@ -29,11 +29,11 @@
...
@@ -29,11 +29,11 @@
* be given as a runtime argument containing a single value, or a compile-time
* be given as a runtime argument containing a single value, or a compile-time
* attribute.
* attribute.
*
*
* Inputs: (1)
the shape of the set to be populated.
* Inputs: (1)
randomization seed (uint32)
* (2)
randomization seed (uint32). If not given at inference time, the attribute
* (2)
the shape of the set to be populated.
*
value, or auto seeding, will be used.
*
*
*
*
Attributes:
seed uint32 Randomization seed
* Attributes:
none
*
*
* Output: Same shape.
* Output: Same shape.
*
*
...
@@ -53,7 +53,8 @@ namespace op {
...
@@ -53,7 +53,8 @@ namespace op {
struct
rand_uniform
struct
rand_uniform
{
{
uint32_t
seed
=
{
0
};
// The rand_uniform operation does not contain a random number generator seed
// as a member, and expects it to be passed as a runtime input.
// todo: not currently settable
// todo: not currently settable
float
range_min
=
0.0
f
;
float
range_min
=
0.0
f
;
...
@@ -65,17 +66,21 @@ struct rand_uniform
...
@@ -65,17 +66,21 @@ struct rand_uniform
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
dtype
,
"dtype"
)
,
f
(
self
.
seed
,
"seed"
)
);
return
pack
(
f
(
self
.
dtype
,
"dtype"
));
}
}
/**
* Input 1: seed
* Input 2: output shape
*/
std
::
string
name
()
const
{
return
"rand_uniform"
;
}
std
::
string
name
()
const
{
return
"rand_uniform"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
2
);
if
(
inputs
.
size
()
>
1
and
inputs
.
a
t
(
1
).
type
()
!=
shape
::
type_t
::
uint32_type
)
if
(
inputs
.
fron
t
().
type
()
!=
shape
::
type_t
::
uint32_type
)
MIGRAPHX_THROW
(
"RAND_UNIFORM: Input 2 (seed) must have type unsigned int"
);
MIGRAPHX_THROW
(
"RAND_UNIFORM: Input 2 (seed) must have type unsigned int"
);
auto
s
=
inputs
.
fron
t
();
auto
s
=
inputs
.
a
t
(
1
);
if
(
s
.
dynamic
())
if
(
s
.
dynamic
())
{
{
return
s
.
with_type
(
dtype
);
return
s
.
with_type
(
dtype
);
...
@@ -86,25 +91,22 @@ struct rand_uniform
...
@@ -86,25 +91,22 @@ struct rand_uniform
}
}
}
}
argument
compute
(
const
dyn_output
&
dyn_o
ut
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
outp
ut
,
std
::
vector
<
argument
>
&
args
)
const
{
{
argument
result
{
dyn_out
.
computed_shape
};
(
void
)
output
;
argument
&
result
{
args
[
1
]};
auto
local_seed
(
seed
);
uint32_t
local_seed
=
args
[
0
].
at
<
uint32_t
>
(
0
);
if
(
args
.
size
()
>
1
)
{
local_seed
=
args
[
1
].
at
<
uint32_t
>
(
0
);
}
// If a seed argument was not defined, use the value from the seed attribute,
// or the default.
std
::
mt19937
gen
(
local_seed
);
std
::
mt19937
gen
(
local_seed
);
std
::
uniform_real_distribution
<>
dis
(
range_min
,
range_max
);
std
::
uniform_real_distribution
<>
dis
(
range_min
,
range_max
);
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
_shape
)
{
std
::
generate
(
output
.
begin
(),
output
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
std
::
generate
(
output
_shape
.
begin
(),
output
_shape
.
end
(),
[
&
]()
{
return
dis
(
gen
);
});
});
});
return
result
;
return
result
;
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
1
;
}
};
};
}
// namespace op
}
// namespace op
...
...
test/op_shape_test.cpp
View file @
2ef739bd
...
@@ -2219,17 +2219,9 @@ TEST_CASE(prefix_scan_sum_dyn_2d)
...
@@ -2219,17 +2219,9 @@ TEST_CASE(prefix_scan_sum_dyn_2d)
TEST_CASE
(
rand_uniform
)
TEST_CASE
(
rand_uniform
)
{
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
5
,
8
},
{
3
,
7
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
5
,
8
},
{
3
,
7
}};
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
1
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
dd
};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
dd
};
expect_shape
(
s1
,
migraphx
::
make_op
(
"rand_uniform"
,
{{
"seed"
,
1
}}),
s1
);
expect_shape
(
s1
,
migraphx
::
make_op
(
"rand_uniform"
),
s0
,
s1
);
}
TEST_CASE
(
rand_uniform_2args
)
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
5
,
8
},
{
3
,
7
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
dd
};
migraphx
::
shape
s2
{
migraphx
::
shape
::
uint32_type
,
dd
};
expect_shape
(
s1
,
migraphx
::
make_op
(
"rand_uniform"
,
{{
"seed"
,
1
}}),
s1
,
s2
);
}
}
TEST_CASE
(
random_seed
)
TEST_CASE
(
random_seed
)
...
...
test/ref_ops_test.cpp
View file @
2ef739bd
...
@@ -6477,12 +6477,7 @@ TEST_CASE(rand_uniform_test)
...
@@ -6477,12 +6477,7 @@ TEST_CASE(rand_uniform_test)
std::vector<uint32_t> seed_data{seed};
std::vector<uint32_t> seed_data{seed};
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
auto seed_input = mm->add_literal(migraphx::literal(seed_shape, seed_data));
mm->add_instruction(migraphx::make_op("rand_uniform",
mm->add_instruction(migraphx::make_op("rand_uniform"), seed_input, input);
{
{"seed", seed},
}),
input,
seed_input);
p.compile(migraphx::make_target("ref"));
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params0;
migraphx::parameter_map params0;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment