Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8f08607e
Unverified
Commit
8f08607e
authored
Mar 29, 2019
by
mvermeulen
Committed by
GitHub
Mar 29, 2019
Browse files
Merge branch 'develop' into rm_identity
parents
e2cf822d
15ba8a36
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
182 additions
and
20 deletions
+182
-20
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+14
-15
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+12
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-0
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+24
-0
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-0
src/schedule.cpp
src/schedule.cpp
+2
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+14
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-1
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+105
-0
No files found.
src/fwd_conv_batchnorm_rewrite.cpp
View file @
8f08607e
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -14,32 +15,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
if
(
not
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[](
auto
arg
)
{
// Get scale, bias, mean, variance from inputs
return
arg
->
name
()
==
"@literal"
;
auto
gamma
=
ins
->
inputs
()[
1
]
->
eval
();
}))
auto
bias
=
ins
->
inputs
()[
2
]
->
eval
();
auto
mean
=
ins
->
inputs
()[
3
]
->
eval
();
auto
variance
=
ins
->
inputs
()[
4
]
->
eval
();
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
auto
conv_ins
=
ins
->
inputs
()[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
continue
;
if
(
conv_ins
->
inputs
()[
1
]
->
name
()
!=
"@literal"
)
// Get convolution weights
auto
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
continue
;
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
inputs
()[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()[
4
]
->
get_literal
();
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
inputs
()[
1
]
->
get_literal
();
// Get convolution op
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
auto
conv_op
=
conv_ins
->
get_operator
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
argument
new_weights
{
weights
.
get_shape
()};
argument
new_weights
{
weights
.
get_shape
()};
argument
new_bias
{
bias
.
get_shape
()};
argument
new_bias
{
{
bias
.
get_shape
()
.
type
(),
{
bias
.
get_shape
().
elements
()}}
};
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
auto
weights2
,
[
&
](
auto
weights2
,
auto
gamma2
,
auto
gamma2
,
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -51,11 +50,11 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
(
k
)
/
std
::
sqrt
(
variance2
(
k
)
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
(
c
)
=
new_bias2
[
c
]
=
bias2
(
c
)
-
(
gamma2
(
c
)
*
mean2
(
c
)
/
std
::
sqrt
(
variance2
(
c
)
+
epsilon
));
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
});
});
// Replace convolution instruction with updated weights
// Replace convolution instruction with updated weights
...
...
src/include/migraphx/check_shapes.hpp
View file @
8f08607e
...
@@ -18,6 +18,11 @@ struct check_shapes
...
@@ -18,6 +18,11 @@ struct check_shapes
{
{
}
}
template
<
class
Op
>
check_shapes
(
const
shape
*
b
,
const
shape
*
e
,
const
Op
&
op
)
:
begin
(
b
),
end
(
e
),
name
(
op
.
name
())
{
}
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
())
{}
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
begin
(
s
.
data
()),
end
(
s
.
data
()
+
s
.
size
())
{}
template
<
class
Op
>
template
<
class
Op
>
...
@@ -119,6 +124,13 @@ struct check_shapes
...
@@ -119,6 +124,13 @@ struct check_shapes
return
*
this
;
return
*
this
;
}
}
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
template
<
class
F
>
template
<
class
F
>
bool
same
(
F
f
)
const
bool
same
(
F
f
)
const
{
{
...
...
src/include/migraphx/operators.hpp
View file @
8f08607e
...
@@ -56,6 +56,9 @@ struct batch_norm_inference
...
@@ -56,6 +56,9 @@ struct batch_norm_inference
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
.
data
(),
inputs
.
data
()
+
1
,
*
this
}.
only_dims
(
4
);
check_shapes
{
inputs
.
data
()
+
1
,
inputs
.
data
()
+
inputs
.
size
(),
*
this
}.
same_shape
().
elements
(
inputs
.
front
().
lens
()[
1
]);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
};
};
...
...
src/include/migraphx/ranges.hpp
View file @
8f08607e
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
...
@@ -71,6 +71,30 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
return
std
::
all_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
}
template
<
class
C
,
class
Predicate
>
bool
any_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
any_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
any_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
C
,
class
Predicate
>
bool
none_of
(
const
C
&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
T
,
class
Predicate
>
bool
none_of
(
const
std
::
initializer_list
<
T
>&
c
,
const
Predicate
&
p
)
{
return
std
::
none_of
(
c
.
begin
(),
c
.
end
(),
p
);
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
...
...
src/include/migraphx/schedule.hpp
View file @
8f08607e
...
@@ -17,6 +17,7 @@ struct program;
...
@@ -17,6 +17,7 @@ struct program;
struct
schedule
struct
schedule
{
{
schedule_model
model
{};
schedule_model
model
{};
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
...
...
src/schedule.cpp
View file @
8f08607e
...
@@ -341,6 +341,8 @@ struct stream_info
...
@@ -341,6 +341,8 @@ struct stream_info
void
schedule
::
apply
(
program
&
p
)
const
void
schedule
::
apply
(
program
&
p
)
const
{
{
if
(
not
enable
)
return
;
stream_info
si
;
stream_info
si
;
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
p
.
end
());
si
.
accumulate_weights
(
last
,
model
);
si
.
accumulate_weights
(
last
,
model
);
...
...
src/targets/cpu/lowering.cpp
View file @
8f08607e
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
...
@@ -75,10 +75,10 @@ struct cpu_batch_norm_inference
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
par_dfor
(
num_batch
,
num_channels
,
image_height
,
image_width
)(
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
[
&
](
std
::
size_t
n
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
assert
((
variance
(
c
)
+
epsilon
)
>
0
);
assert
((
variance
[
c
]
+
epsilon
)
>
0
);
result
(
n
,
c
,
h
,
w
)
=
gamma
(
c
)
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
(
c
)
)
/
result
(
n
,
c
,
h
,
w
)
=
gamma
[
c
]
*
(
buffer
(
n
,
c
,
h
,
w
)
-
mean
[
c
]
)
/
std
::
sqrt
(
variance
(
c
)
+
epsilon
)
+
std
::
sqrt
(
variance
[
c
]
+
epsilon
)
+
bias
(
c
)
;
bias
[
c
]
;
});
});
});
});
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
8f08607e
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
...
@@ -140,6 +140,8 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
if
(
conv
.
op
.
group
>
1
)
if
(
conv
.
op
.
group
>
1
)
return
false
;
return
false
;
if
(
conv
.
op
.
padding_mode
!=
op
::
padding_mode_t
::
default_
)
return
false
;
if
(
wei
.
lens
()[
1
]
>
512
and
conv
.
algo
!=
miopenConvolutionFwdAlgoWinograd
)
if
(
wei
.
lens
()[
1
]
>
512
and
conv
.
algo
!=
miopenConvolutionFwdAlgoWinograd
)
return
false
;
return
false
;
auto
op
=
conv
.
op
;
auto
op
=
conv
.
op
;
...
@@ -251,6 +253,12 @@ struct miopen_conv_bias
...
@@ -251,6 +253,12 @@ struct miopen_conv_bias
fusion
::
op_t
conv
;
fusion
::
op_t
conv
;
fusion
::
op_t
bias
;
fusion
::
op_t
bias
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
}
miopen_conv_bias
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
b
)
miopen_conv_bias
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
b
)
:
op
(
c
),
f
(
input
)
:
op
(
c
),
f
(
input
)
{
{
...
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
...
@@ -288,6 +296,12 @@ struct miopen_conv_bias_relu
fusion
::
op_t
bias
;
fusion
::
op_t
bias
;
fusion
::
op_t
relu
;
fusion
::
op_t
relu
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
op
::
convolution
::
reflect
(
self
.
op
,
f
);
}
miopen_conv_bias_relu
(
op
::
convolution
c
,
miopen_conv_bias_relu
(
op
::
convolution
c
,
const
shape
&
input
,
const
shape
&
input
,
const
shape
&
weights
,
const
shape
&
weights
,
...
...
src/targets/gpu/target.cpp
View file @
8f08607e
...
@@ -26,6 +26,8 @@ namespace migraphx {
...
@@ -26,6 +26,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_SCHEDULE_PASS
)
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
gctx
)
const
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
gctx
)
const
{
{
auto
&
ctx
=
any_cast
<
context
>
(
gctx
);
auto
&
ctx
=
any_cast
<
context
>
(
gctx
);
...
@@ -55,7 +57,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -55,7 +57,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fuse_ops
{
&
ctx
},
fuse_ops
{
&
ctx
},
dead_code_elimination
{},
dead_code_elimination
{},
write_literals
{
&
ctx
},
write_literals
{
&
ctx
},
schedule
{
gpu
::
schedule_model
{
ctx
.
get_current_device
().
nstreams
()}},
schedule
{
gpu
::
schedule_model
{
ctx
.
get_current_device
().
nstreams
()}
,
enabled
(
MIGRAPHX_ENABLE_SCHEDULE_PASS
{})
},
memory_coloring
{
"hip::allocate"
},
memory_coloring
{
"hip::allocate"
},
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_workspace
{},
eliminate_workspace
{},
...
...
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
8f08607e
...
@@ -3,9 +3,13 @@
...
@@ -3,9 +3,13 @@
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <test.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
bool
is_batch_norm
(
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"batch_norm_inference"
;
}
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
{
std
::
vector
<
float
>
xdata
=
{
std
::
vector
<
float
>
xdata
=
{
...
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
...
@@ -65,4 +69,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
EXPECT
(
migraphx
::
verify_range
(
results_vector1
,
results_vector2
));
}
}
TEST_CASE
(
non_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
xs
);
auto
w
=
p
.
add_parameter
(
"w"
,
ws
);
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p2
,
&
is_batch_norm
));
}
TEST_CASE
(
as_literal
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
)));
auto
bias
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
)));
auto
mean
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
)));
auto
variance
=
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
)));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
TEST_CASE
(
literal_reshape
)
{
migraphx
::
shape
xs
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}};
migraphx
::
shape
ws
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
1
}};
migraphx
::
shape
vars
{
migraphx
::
shape
::
float_type
,
{
4
}};
auto
create_program
=
[
&
]()
{
migraphx
::
program
p
;
auto
reshape
=
[
&
](
auto
ins
)
{
return
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
4
,
1
,
1
}},
ins
);
};
auto
x
=
p
.
add_literal
(
migraphx
::
generate_literal
(
xs
,
1
));
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
(
ws
,
1
));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{},
x
,
w
);
auto
scale
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
1
))));
auto
bias
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
2
))));
auto
mean
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
3
))));
auto
variance
=
reshape
(
p
.
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
vars
,
4
))));
p
.
add_instruction
(
migraphx
::
op
::
batch_norm_inference
{},
conv
,
scale
,
bias
,
mean
,
variance
);
return
p
;
};
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv_batchnorm_rewrite
opt
;
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result1
=
p1
.
eval
({});
auto
result2
=
p2
.
eval
({});
visit_all
(
result1
,
result2
)([
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment