Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a4a5827
Unverified
Commit
0a4a5827
authored
Apr 02, 2019
by
mvermeulen
Committed by
GitHub
Apr 02, 2019
Browse files
Merge branch 'develop' into pad_op_rewrite
parents
78146d21
dc85aa6b
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
180 additions
and
22 deletions
+180
-22
src/eliminate_identity.cpp
src/eliminate_identity.cpp
+0
-1
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+14
-15
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+12
-0
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+3
-1
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+14
-0
test/eliminate_identity_test.cpp
test/eliminate_identity_test.cpp
+1
-1
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+105
-0
No files found.
src/eliminate_identity.cpp
View file @
0a4a5827
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
#include <utility>
...
...
src/fwd_conv_batchnorm_rewrite.cpp
View file @
0a4a5827
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
if
(
not
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
// Get scale, bias, mean, variance from inputs
return
arg
->
name
()
==
"@literal"
;
auto
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
}))
auto
bias
=
ins
->
inputs
()[
2
]
->
eval
();
auto
mean
=
ins
->
inputs
()[
3
]
->
eval
();
auto
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
auto
conv_ins
=
ins
->
inputs
()[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
continue
;
if
(
conv_ins
->
inputs
()[
1
]
->
name
()
!=
"@literal"
)
// Get convolution weights
auto
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
continue
;
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
get_literal
();
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
get_literal
();
// Get convolution op
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
auto
conv_op
=
conv_ins
->
get_operator
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
argument
new_weights
{
weights
.
get_shape
()};
argument
new_weights
{
weights
.
get_shape
()};
argument
new_bias
{
bias
.
get_shape
()};
argument
new_bias
{
{
bias
.
get_shape
()
.
type
(),
{
bias
.
get_shape
().
elements
()}}
};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
[
&
](
auto
weights2
,
auto
gamma2
,
auto
gamma2
,
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
(
k
)
/
std
::
sqrt
(
variance2
(
k
)
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
(
c
)
=
new_bias2
[
c
]
=
bias2
(
c
)
-
(
gamma2
(
c
)
*
mean2
(
c
)
/
std
::
sqrt
(
variance2
(
c
)
+
epsilon
));
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
});
});
// Replace convolution instruction with updated weights
// Replace convolution instruction with updated weights
...
...
src/include/migraphx/check_shapes.hpp
View file @
0a4a5827
...
@@ -18,6 +18,11 @@ struct check_shapes
...
@@ -18,6 +18,11 @@ struct check_shapes
{
{
}
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
())
{
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
())
{}
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
())
{}
template
<
class
Op
>
template
<
class
Op
>
...
@@ -119,6 +124,13 @@ struct check_shapes
...
@@ -119,6 +124,13 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
template
<
class
F
>
template
<
class
F
>
bool
same
(
F
f
)
const
bool
same
(
F
f
)
const
{
{
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
0a4a5827
...
@@ -11,7 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,7 +11,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
struct
program
;
/**
/**
* Remove identity instructions.
* Remove identity instructions. Currently when used as the last pass, it will
* preserve the semantics of previous program state, therefore dead code elimination
* should not be used afterwards.
*/
*/
struct
eliminate_identity
struct
eliminate_identity
{
{
...
...
src/include/migraphx/operators.hpp
View file @
0a4a5827
...
@@ -56,6 +56,9 @@ struct batch_norm_inference
...
@@ -56,6 +56,9 @@ struct batch_norm_inference
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
.
data
(),
inputs
.
data
()
+
1
,
*
this
}.
only_dims
(
4
);
check_shapes
{
inputs
.
data
()
+
1
,
inputs
.
data
()
+
inputs
.
size
(),
*
this
}.
same_shape
().
elements
(
inputs
.
front
().
lens
()[
1
]);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
};
};
...
...
src/include/migraphx/ranges.hpp
View file @
0a4a5827
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
any_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
none_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/targets/cpu/lowering.cpp
View file @
0a4a5827
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
assert
((
variance
(
c
)
+
epsilon
)
>
0
);
assert
((
variance
[
c
]
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
(
c
)
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
(
c
)
)
/
result
(
n
,
c
,
h
,
w
)
=
gamma
[
c
]
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
[
c
]
)
/
std
::
sqrt
(
variance
(
c
)
+
epsilon
)
+
std
::
sqrt
(
variance
[
c
]
+
epsilon
)
+
bias
(
c
)
;
bias
[
c
]
;
});
});
});
});
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
0a4a5827
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
if
(
conv
.
op
.
group
>
1
)
if
(
conv
.
op
.
group
>
1
)
return
false
;
return
false
;
if
(
conv
.
op
.
padding_mode
!=
op
::
padding_mode_t
::
default_
)
return
false
;
if
(
wei
.
lens
()[
1
]
>
512
and
conv
.
algo
!=
miopenConvolutionFwdAlgoWinograd
)
if
(
wei
.
lens
()[
1
]
>
512
and
conv
.
algo
!=
miopenConvolutionFwdAlgoWinograd
)
return
false
;
return
false
;
auto
op
=
conv
.
op
;
auto
op
=
conv
.
op
;
...
@@ -251,6 +253,12 @@ struct miopen_conv_bias
...
@@ -251,6 +253,12 @@ struct miopen_conv_bias
fusion
::
op_t
conv
;
fusion
::
op_t
conv
;
fusion
::
op_t
bias
;
fusion
::
op_t
bias
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
}
miopen_conv_bias
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
b
)
miopen_conv_bias
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
b
)
:
op
(
c
),
f
(
input
)
:
op
(
c
),
f
(
input
)
{
{
...
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
...
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
fusion
::
op_t
bias
;
fusion
::
op_t
bias
;
fusion
::
op_t
relu
;
fusion
::
op_t
relu
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
}
miopen_conv_bias_relu
(
op
::
convolution
c
,
miopen_conv_bias_relu
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
weights
,
...
...
test/eliminate_identity_test.cpp
View file @
0a4a5827
...
@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency)
...
@@ -59,7 +59,7 @@ TEST_CASE(simple_test_end_dependency)
p
.
add_instruction
(
sum_op
{},
ans
,
three
);
p
.
add_instruction
(
sum_op
{},
ans
,
three
);
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
ans
);
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
ans
);
p
.
compile
(
eliminate_identity_target
{});
p
.
compile
(
eliminate_identity_target
{});
EXPECT
(
!
std
::
none
_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
EXPECT
(
std
::
any
_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"identity"
;
return
ins
.
name
()
==
"identity"
;
}));
}));
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
0a4a5827
...
@@ -3,9 +3,13 @@
...
@@ -3,9 +3,13 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
bool
is_batch_norm
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"batch_norm_inference"
;
}
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
{
std
::
vector
<
float
>
xdata
=
{
std
::
vector
<
float
>
xdata
=
{
...
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
...
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
}
}
TEST_CASE
(
non_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
xs
);
auto
w
=
p
.
add_parameter
(
"w"
,
ws
);
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p2
,
&
is_batch_norm
));
}
TEST_CASE
(
as_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
TEST_CASE
(
literal_reshape
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
reshape
=
[
&
](
auto
ins
)
{
return
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
4
,
1
,
1
}},
ins
);
};
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
))));
auto
bias
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
))));
auto
mean
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
))));
auto
variance
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
))));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment