Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
24cc6ea8
Commit
24cc6ea8
authored
Aug 21, 2018
by
Paul
Browse files
Fix bn tests
parent
e01c70e6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
42 additions
and
18 deletions
+42
-18
src/generate.cpp
src/generate.cpp
+10
-2
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+11
-8
src/include/migraph/literal.hpp
src/include/migraph/literal.hpp
+13
-0
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+8
-8
No files found.
src/generate.cpp
View file @
24cc6ea8
...
...
@@ -2,7 +2,7 @@
namespace
migraph
{
argument
generate_argument
(
shape
s
,
std
::
mt19937
::
result_type
seed
)
argument
generate_argument
(
shape
s
,
unsigned
long
seed
)
{
argument
result
;
s
.
visit_type
([
&
](
auto
as
)
{
...
...
@@ -13,7 +13,7 @@ argument generate_argument(shape s, std::mt19937::result_type seed)
return
result
;
}
literal
generate_literal
(
shape
s
,
std
::
mt19937
::
result_type
seed
)
literal
generate_literal
(
shape
s
,
unsigned
long
seed
)
{
literal
result
;
s
.
visit_type
([
&
](
auto
as
)
{
...
...
@@ -24,4 +24,12 @@ literal generate_literal(shape s, std::mt19937::result_type seed)
return
result
;
}
// TODO: Move to literal.cpp
literal
abs
(
literal
l
)
{
return
transform
(
l
,
[](
auto
x
)
{
return
std
::
fabs
(
x
);
});
}
}
// namespace migraph
src/include/migraph/generate.hpp
View file @
24cc6ea8
...
...
@@ -8,7 +8,7 @@
namespace
migraph
{
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_floating_point
<
T
>{})
>
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
if
(
z
==
0
)
return
0
;
...
...
@@ -16,7 +16,7 @@ T normalize(unsigned long z)
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
std
::
is_signed
<
T
>{}
and
not
std
::
is_floating_point
<
T
>
{})
>
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
half_max
=
max
/
2
;
...
...
@@ -24,7 +24,7 @@ T normalize(unsigned long z)
}
template
<
class
T
,
MIGRAPH_REQUIRES
(
not
std
::
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
return
z
%
max
;
...
...
@@ -33,9 +33,10 @@ T normalize(unsigned long z)
template
<
class
T
>
struct
xorshf96_generator
{
unsigned
long
seed
=
0
;
unsigned
long
x
=
123456789
;
unsigned
long
y
=
362436069
;
unsigned
long
z
=
521288629
;
unsigned
long
z
=
521288629
^
seed
;
constexpr
T
operator
()()
noexcept
{
...
...
@@ -53,16 +54,18 @@ struct xorshf96_generator
};
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
std
::
mt19937
::
result_type
)
std
::
vector
<
T
>
generate_tensor_data
(
const
migraph
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{});
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
return
result
;
}
argument
generate_argument
(
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
);
argument
generate_argument
(
shape
s
,
unsigned
long
seed
=
0
);
literal
generate_literal
(
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
);
literal
generate_literal
(
shape
s
,
unsigned
long
seed
=
0
);
literal
abs
(
literal
l
);
}
// namespace migraph
...
...
src/include/migraph/literal.hpp
View file @
24cc6ea8
...
...
@@ -94,6 +94,19 @@ struct literal : raw_data<literal>
}
};
template
<
class
F
>
literal
transform
(
literal
l
,
F
f
)
{
literal
result
;
l
.
visit
([
&
](
auto
x
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
x
)
::
value_type
>
;
std
::
vector
<
type
>
output
(
x
.
size
(),
0.0
);
std
::
transform
(
x
.
begin
(),
x
.
end
(),
output
.
begin
(),
f
);
result
=
literal
{
l
.
get_shape
(),
output
};
});
return
result
;
}
}
// namespace migraph
#endif
test/gpu/miopen.cpp
View file @
24cc6ea8
...
...
@@ -332,10 +332,10 @@ struct test_batchnorm_inference_2
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
batches
,
channels
,
height
,
width
}};
migraph
::
shape
vars
{
migraph
::
shape
::
float_type
,
{
channels
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
mean
=
p
.
add_
parameter
(
"mean"
,
vars
);
auto
variance
=
p
.
add_
parameter
(
"variance"
,
vars
);
auto
scale
=
p
.
add_
parameter
(
"scale"
,
vars
);
auto
bias
=
p
.
add_
parameter
(
"bias"
,
vars
);
auto
mean
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
0
))
);
auto
variance
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
1
))
);
auto
scale
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
2
))
);
auto
bias
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
3
))
);
p
.
add_instruction
(
migraph
::
batch_norm_inference
{},
x
,
mean
,
variance
,
scale
,
bias
);
return
p
;
}
...
...
@@ -355,10 +355,10 @@ struct test_batchnorm_inference
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
batches
,
channels
,
height
,
width
}};
migraph
::
shape
vars
{
migraph
::
shape
::
float_type
,
{
channels
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
mean
=
p
.
add_
parameter
(
"mean"
,
vars
);
auto
variance
=
p
.
add_
parameter
(
"variance"
,
vars
);
auto
scale
=
p
.
add_
parameter
(
"scale"
,
vars
);
auto
bias
=
p
.
add_
parameter
(
"bias"
,
vars
);
auto
mean
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
0
))
);
auto
variance
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
1
))
);
auto
scale
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
2
))
);
auto
bias
=
p
.
add_
literal
(
migraph
::
abs
(
migraph
::
generate_literal
(
vars
,
3
))
);
p
.
add_instruction
(
migraph
::
batch_norm_inference
{},
x
,
mean
,
variance
,
scale
,
bias
);
return
p
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment