Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e7c8ac31
Commit
e7c8ac31
authored
Jul 24, 2023
by
Brian Pickrell
Browse files
Refactored inputs for rand_uniform and refactored multinomial_dyn_test.
parent
f240996d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
39 deletions
+42
-39
src/include/migraphx/op/multinomial.hpp
src/include/migraphx/op/multinomial.hpp
+4
-2
src/include/migraphx/op/rand_uniform.hpp
src/include/migraphx/op/rand_uniform.hpp
+11
-15
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+27
-22
No files found.
src/include/migraphx/op/multinomial.hpp
View file @
e7c8ac31
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
* each category, or bucket. This does not require the standard multinomial
* each category, or bucket. This does not require the standard multinomial
* distribution but instead takes a probability distribution as an input.
* distribution but instead takes a probability distribution as an input.
*
*
* In the large number limit, the fractional counts approach the multinomial distribution.
*
* Inputs: args[0] - a vector of probabilities for each category. Values are running totals
* Inputs: args[0] - a vector of probabilities for each category. Values are running totals
as provided by op prefix_scan_sum.
as provided by op prefix_scan_sum.
* Values are log normalized (i.e. start with any set of numbers > 0, then
* Values are log normalized (i.e. start with any set of numbers > 0, then
...
@@ -87,7 +89,7 @@ struct multinomial
...
@@ -87,7 +89,7 @@ struct multinomial
// return a static shape
// return a static shape
if
((
not
inputs
.
front
().
dynamic
())
or
(
inputs
.
front
().
dyn_dims
().
front
().
is_fixed
()))
if
((
not
inputs
.
front
().
dynamic
())
or
(
inputs
.
front
().
dyn_dims
().
front
().
is_fixed
()))
{
{
if
((
not
inputs
.
back
().
dynamic
())
or
(
inputs
.
back
().
dyn_dims
().
front
().
is_fixed
()))
if
((
not
inputs
.
back
().
dynamic
())
or
(
inputs
.
back
().
dyn_dims
().
back
().
is_fixed
()))
{
{
size_t
batch
=
{
inputs
.
front
().
max_lens
().
front
()};
size_t
batch
=
{
inputs
.
front
().
max_lens
().
front
()};
size_t
sample_size
{
inputs
.
back
().
max_lens
().
back
()};
size_t
sample_size
{
inputs
.
back
().
max_lens
().
back
()};
...
@@ -96,7 +98,7 @@ struct multinomial
...
@@ -96,7 +98,7 @@ struct multinomial
}
}
return
{
dtype
,
return
{
dtype
,
{
inputs
.
front
().
to_dynamic
().
dyn_dims
().
front
(),
{
inputs
.
front
().
to_dynamic
().
dyn_dims
().
front
(),
inputs
.
back
().
to_dynamic
().
dyn_dims
().
front
()}};
inputs
.
back
().
to_dynamic
().
dyn_dims
().
back
()}};
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/rand_uniform.hpp
View file @
e7c8ac31
...
@@ -47,20 +47,17 @@ namespace op {
...
@@ -47,20 +47,17 @@ namespace op {
struct
rand_uniform
struct
rand_uniform
{
{
uint32_t
sample_size
=
{
20
};
uint32_t
sample_size
=
{
20
};
uint32_t
seed
=
{
0
};
uint32_t
seed
=
{
0
};
shape
::
type_t
dtype
=
shape
::
type_t
::
float_type
;
shape
::
type_t
dtype
=
shape
::
type_t
::
float_type
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
dtype
,
"dtype"
),
f
(
self
.
sample_size
,
"sample_size"
),
f
(
self
.
seed
,
"seed"
));
return
pack
(
}
f
(
self
.
dtype
,
"dtype"
),
f
(
self
.
sample_size
,
"sample_size"
),
f
(
self
.
seed
,
"seed"
));
value
attributes
()
const
{
return
{{
"sample_size"
,
sample_size
},
{
"seed"
,
seed
}};
}
}
value
attributes
()
const
{
return
{{
"sample_size"
,
sample_size
},
{
"seed"
,
seed
}};
}
std
::
string
name
()
const
{
return
"rand_uniform"
;
}
std
::
string
name
()
const
{
return
"rand_uniform"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
@@ -81,23 +78,22 @@ struct rand_uniform
...
@@ -81,23 +78,22 @@ struct rand_uniform
}
}
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
(
void
)
args
;
// suppress compiler warning
argument
result
{
dyn_out
.
computed_shape
};
argument
result
{
dyn_out
.
computed_shape
};
std
::
mt19937
gen
(
seed
);
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
std
::
uniform_real_distribution
<>
dis
(
0.0
,
1.0
);
size_t
index
(
dyn_out
.
computed_shape
.
elements
());
size_t
elts
(
dyn_out
.
computed_shape
.
elements
());
// Use of our visitor and par_for replaces a call like
// Use of our visitor and par_for replaces a call like
// std::vector<float> rand_samples(sample_size);
// std::vector<float> rand_samples(sample_size);
// std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
// std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
par_for
(
sample_size
,
[
&
](
auto
i
)
par_for
(
elts
,
[
&
](
auto
i
)
{
{
output
[
i
]
=
dis
(
gen
);
output
[
i
]
=
dis
(
gen
);
// output[i] = rand_samples[i];
// output[i] = rand_samples[i];
});
});
});
});
return
result
;
return
result
;
...
...
test/ref_ops_test.cpp
View file @
e7c8ac31
...
@@ -4909,7 +4909,6 @@ TEST_CASE(multinomial_test)
...
@@ -4909,7 +4909,6 @@ TEST_CASE(multinomial_test)
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<float> data(5);
std::vector<float> data(5);
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
auto input = mm->add_literal(migraphx::literal(s, data));
auto input = mm->add_literal(migraphx::literal(s, data));
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input);
...
@@ -4927,8 +4926,8 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
...
@@ -4927,8 +4926,8 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
std::vector<int32_t> result_vec(sample_size);
std::vector<int32_t> result_vec(sample_size);
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// res_dist is a count, or histogram, of the number of samples in each category. This is the
sampled
// res_dist is a count, or histogram, of the number of samples in each category. This is the
// distribution.
//
sampled
distribution.
std::vector<int> res_dist(5, 0);
std::vector<int> res_dist(5, 0);
for(const auto& r : result_vec)
for(const auto& r : result_vec)
res_dist[r]++;
res_dist[r]++;
...
@@ -4937,7 +4936,7 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
...
@@ -4937,7 +4936,7 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
// and the sampling result res_dist; they should be close
// and the sampling result res_dist; they should be close
// Total the unnormalized probabilities
// Total the unnormalized probabilities
auto dist_sum
= std::accumulate(dist.begin(), dist.end(), 0);
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
// Total the number of values returned
// Total the number of values returned
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
...
@@ -4949,7 +4948,7 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
...
@@ -4949,7 +4948,7 @@ printf("data="); for(auto aa:data)printf(", %f", aa);printf("\n");
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
return static_cast<double>(n) / res_dist_sum;
});
});
printf("cumulative distribution of result ="); for(auto aa:res_norm)printf(", %f", aa);printf("\n");
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
}
...
@@ -4960,31 +4959,35 @@ TEST_CASE(multinomial_dyn_test)
...
@@ -4960,31 +4959,35 @@ TEST_CASE(multinomial_dyn_test)
size_t sample_size = 1000000;
size_t sample_size = 1000000;
float seed = 0.0f;
float seed = 0.0f;
// Shape of the random data
// Shape of the random data
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2
3
, sample_size + 1}}};
migraphx::shape rs{migraphx::shape::float_type, {{1, 2}, {2, sample_size + 1}}};
auto input = mm->add_parameter("Input_1", rs);
auto input = mm->add_parameter("Input_1", rs);
// the probability distribution, which also defines the number of categories
//
Shape of
the probability distribution, which also defines the number of categories
migraphx::shape s{migraphx::shape::float_type, {{1,
2
}, {5, 6}}};
migraphx::shape s{migraphx::shape::float_type, {{1,
1
}, {5, 6}}};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<int> dist{15, 25, 15, 25, 20};
std::vector<float> data(5);
std::vector<float> data(5);
// todo: make this an instruction too
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return d; });
std::transform(dist.begin(), dist.end(), data.begin(), [&](auto d) { return std::log(d); });
auto input2 = mm->add_parameter("Input_2", s);
auto input2 = mm->add_parameter("Input_2", s);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input2);
// The next several instructions log-normalize the probability distribution,
auto mb_maxes =
// as required by the multinomial operation
mm->add_instruction(migraphx::make_op("multibroadcast"), maxes, input2);
auto logs = mm->add_instruction(migraphx::make_op("log"), input2);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), input2, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), logs);
cdf = mm->add_instruction(
auto mb_maxes = mm->add_instruction(migraphx::make_op("multibroadcast"), maxes, input2);
auto cdf = mm->add_instruction(migraphx::make_op("sub"), logs, mb_maxes);
cdf = mm->add_instruction(migraphx::make_op("exp"), cdf);
cdf = mm->add_instruction(
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform", {{"seed", seed}}), input);
auto randoms = mm->add_instruction(migraphx::make_op("rand_uniform", {{"seed", seed}}), input);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
mm->add_instruction(migraphx::make_op("multinomial"), cdf, randoms);
p.compile(migraphx::make_target("ref"));
p.compile(migraphx::make_target("ref"));
// Create a dummy input in the shape we want for the random data
// Create a dummy input in the shape we want for the random data
...
@@ -4992,16 +4995,20 @@ TEST_CASE(multinomial_dyn_test)
...
@@ -4992,16 +4995,20 @@ TEST_CASE(multinomial_dyn_test)
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, sample_size}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, sample_size}};
migraphx::shape input_fixed_shape2{migraphx::shape::float_type, {1, 5}};
migraphx::shape input_fixed_shape2{migraphx::shape::float_type, {1, 5}};
migraphx::parameter_map params0;
migraphx::parameter_map params0;
params0["Input_1"] =
migraphx::argument(input_fixed_shape1, dummy.data());
params0["Input_1"] = migraphx::argument(input_fixed_shape1, dummy.data());
params0["Input_2"] = migraphx::argument(input_fixed_shape2, data.data());
params0["Input_2"] = migraphx::argument(input_fixed_shape2, data.data());
auto result = p.eval(params0).back();
auto result
= p.eval(params0).back();
std::vector<
int32_
t> result_vec(input_fixed_shape2.elements());
std::vector<
floa
t> result_vec(input_fixed_shape2.elements());
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
// Make a categorical histogram of output
std::vector<int> res_dist(5, 0);
std::vector<int> res_dist(5, 0);
for(const auto& r : result_vec)
for(const auto& r : result_vec)
res_dist[r]++;
res_dist[r]++;
// Rescale or normalize both the input probability distribution and the output
// histogram, and compare. Should be close but not identical.
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
auto dist_sum = std::accumulate(dist.begin(), dist.end(), 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
auto res_dist_sum = std::accumulate(res_dist.begin(), res_dist.end(), 0);
std::vector<float> norm(5);
std::vector<float> norm(5);
...
@@ -5012,8 +5019,6 @@ TEST_CASE(multinomial_dyn_test)
...
@@ -5012,8 +5019,6 @@ TEST_CASE(multinomial_dyn_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum;
return static_cast<double>(n) / res_dist_sum;
});
});
printf("cumulative distribution of input ="); for(auto aa:norm)printf(", %f", aa);printf("\n");
printf("cumulative distribution of result ="); for(auto aa:res_norm)printf(", %f", aa);printf("\n");
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
EXPECT(migraphx::verify_range(norm, res_norm, 100000));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment