Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2af0a8c
Commit
d2af0a8c
authored
Sep 15, 2018
by
Paul
Browse files
Rename arguments to use inputs function
parent
872ff12a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
45 additions
and
45 deletions
+45
-45
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+3
-3
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+9
-9
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+3
-3
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+1
-1
src/program.cpp
src/program.cpp
+6
-6
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-3
src/targets/cpu/cpu_lowering.cpp
src/targets/cpu/cpu_lowering.cpp
+5
-5
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+3
-3
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+11
-11
No files found.
src/dead_code_elimination.cpp
View file @
d2af0a8c
...
@@ -26,7 +26,7 @@ void dead_code_elimination::apply(program& p) const
...
@@ -26,7 +26,7 @@ void dead_code_elimination::apply(program& p) const
assert
(
p
.
has_instruction
(
leaf
));
assert
(
p
.
has_instruction
(
leaf
));
if
(
leaf
->
output
.
empty
())
if
(
leaf
->
output
.
empty
())
{
{
auto
args
=
leaf
->
arguments
;
auto
args
=
leaf
->
inputs
()
;
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
p
.
move_instruction
(
leaf
,
p
.
end
());
p
.
move_instruction
(
leaf
,
p
.
end
());
for
(
auto
arg
:
args
)
for
(
auto
arg
:
args
)
...
...
src/eliminate_contiguous.cpp
View file @
d2af0a8c
...
@@ -27,15 +27,15 @@ void eliminate_contiguous::apply(program& p) const
...
@@ -27,15 +27,15 @@ void eliminate_contiguous::apply(program& p) const
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
arguments
;
auto
args
=
ins
->
inputs
()
;
for
(
auto
arg
:
ins
->
arguments
)
for
(
auto
arg
:
ins
->
inputs
()
)
{
{
// TODO: Pass in names for the operator in the constructor instead
// TODO: Pass in names for the operator in the constructor instead
// of using ends_with
// of using ends_with
if
(
ends_with
(
arg
->
name
(),
"contiguous"
))
if
(
ends_with
(
arg
->
name
(),
"contiguous"
))
{
{
auto
new_args
=
args
;
auto
new_args
=
args
;
auto
prev
=
arg
->
arguments
.
front
();
auto
prev
=
arg
->
inputs
()
.
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
->
op
,
new_args
))
if
(
try_compute_shape
(
ins
->
op
,
new_args
))
{
{
...
...
src/fwd_conv_batchnorm_rewrite.cpp
View file @
d2af0a8c
...
@@ -12,26 +12,26 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -12,26 +12,26 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
if
(
not
std
::
all_of
(
ins
->
arguments
.
begin
()
+
1
,
ins
->
arguments
.
end
(),
[](
auto
arg
)
{
if
(
not
std
::
all_of
(
ins
->
inputs
()
.
begin
()
+
1
,
ins
->
inputs
()
.
end
(),
[](
auto
arg
)
{
return
arg
->
name
()
==
"@literal"
;
return
arg
->
name
()
==
"@literal"
;
}))
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
arguments
[
0
];
auto
conv_ins
=
ins
->
inputs
()
[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
continue
;
if
(
conv_ins
->
arguments
[
1
]
->
name
()
!=
"@literal"
)
if
(
conv_ins
->
inputs
()
[
1
]
->
name
()
!=
"@literal"
)
continue
;
continue
;
// Get scale, bias, mean, variance from instruction_ref
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
arguments
[
1
]
->
get_literal
();
const
auto
&
gamma
=
ins
->
inputs
()
[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
arguments
[
2
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()
[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
arguments
[
3
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()
[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
arguments
[
4
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()
[
4
]
->
get_literal
();
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
bn_op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
arguments
[
1
]
->
get_literal
();
const
auto
&
weights
=
conv_ins
->
inputs
()
[
1
]
->
get_literal
();
// Get convolution op
// Get convolution op
auto
conv_op
=
conv_ins
->
op
;
auto
conv_op
=
conv_ins
->
op
;
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
...
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Replace convolution instruction with updated weights
// Replace convolution instruction with updated weights
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
arguments
[
0
],
l_weights
});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
inputs
()
[
0
],
l_weights
});
auto
b
=
p
.
insert_instruction
(
ins
,
broadcast
{
1
},
c
,
l_bias
);
auto
b
=
p
.
insert_instruction
(
ins
,
broadcast
{
1
},
c
,
l_bias
);
p
.
replace_instruction
(
ins
,
add
{},
{
c
,
b
});
p
.
replace_instruction
(
ins
,
add
{},
{
c
,
b
});
}
}
...
...
src/include/migraph/instruction.hpp
View file @
d2af0a8c
...
@@ -87,8 +87,8 @@ struct instruction
...
@@ -87,8 +87,8 @@ struct instruction
}
}
return
result
==
computed
&&
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
return
std
::
find
(
i
->
inputs
()
.
begin
(),
i
->
inputs
()
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
i
->
inputs
()
.
end
();
});
});
}
}
...
@@ -156,7 +156,7 @@ struct instruction
...
@@ -156,7 +156,7 @@ struct instruction
inline
void
backreference
(
instruction_ref
ref
)
inline
void
backreference
(
instruction_ref
ref
)
{
{
for
(
auto
&&
arg
:
ref
->
arguments
)
for
(
auto
&&
arg
:
ref
->
inputs
()
)
arg
->
add_output
(
ref
);
arg
->
add_output
(
ref
);
}
}
...
...
src/onnx/verify_onnx.cpp
View file @
d2af0a8c
...
@@ -68,7 +68,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
...
@@ -68,7 +68,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
auto
create_program
=
[
&
]
{
auto
create_program
=
[
&
]
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
for
(
auto
&&
arg
:
ins
.
arguments
)
for
(
auto
&&
arg
:
ins
.
inputs
()
)
{
{
if
(
arg
->
name
()
==
"@literal"
)
if
(
arg
->
name
()
==
"@literal"
)
inputs
.
push_back
(
p
.
add_literal
(
arg
->
lit
));
inputs
.
push_back
(
p
.
add_literal
(
arg
->
lit
));
...
...
src/program.cpp
View file @
d2af0a8c
...
@@ -48,10 +48,10 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -48,10 +48,10 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os
<<
"{"
<<
ins
->
lit
<<
"}"
;
os
<<
"{"
<<
ins
->
lit
<<
"}"
;
}
}
if
(
!
ins
->
arguments
.
empty
())
if
(
!
ins
->
inputs
()
.
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
arguments
)
for
(
auto
&&
arg
:
ins
->
inputs
()
)
{
{
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
os
<<
delim
<<
names
.
at
(
arg
);
os
<<
delim
<<
names
.
at
(
arg
);
...
@@ -93,7 +93,7 @@ instruction_ref program::insert_instruction(instruction_ref ins,
...
@@ -93,7 +93,7 @@ instruction_ref program::insert_instruction(instruction_ref ins,
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
backreference
(
result
);
backreference
(
result
);
// assert(result->
arguments
== args);
// assert(result->
inputs()
== args);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -300,9 +300,9 @@ argument generic_eval(const program& p,
...
@@ -300,9 +300,9 @@ argument generic_eval(const program& p,
}
}
else
else
{
{
values
.
resize
(
ins
->
arguments
.
size
());
values
.
resize
(
ins
->
inputs
()
.
size
());
std
::
transform
(
ins
->
arguments
.
begin
(),
std
::
transform
(
ins
->
inputs
()
.
begin
(),
ins
->
arguments
.
end
(),
ins
->
inputs
()
.
end
(),
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
[
&
](
instruction_ref
i
)
{
assert
(
results
.
find
(
i
)
!=
results
.
end
());
assert
(
results
.
find
(
i
)
!=
results
.
end
());
...
...
src/simplify_reshapes.cpp
View file @
d2af0a8c
...
@@ -35,9 +35,9 @@ void simplify_reshapes::apply(program& p) const
...
@@ -35,9 +35,9 @@ void simplify_reshapes::apply(program& p) const
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
{
{
assert
(
!
reshapes
.
back
()
->
arguments
.
empty
());
assert
(
!
reshapes
.
back
()
->
inputs
()
.
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
arguments
.
front
()));
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
()
.
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
arguments
.
front
());
reshapes
.
push_back
(
reshapes
.
back
()
->
inputs
()
.
front
());
}
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
...
...
src/targets/cpu/cpu_lowering.cpp
View file @
d2af0a8c
...
@@ -597,30 +597,30 @@ struct cpu_apply
...
@@ -597,30 +597,30 @@ struct cpu_apply
template
<
class
T
>
template
<
class
T
>
void
apply_simple_op
(
instruction_ref
ins
)
void
apply_simple_op
(
instruction_ref
ins
)
{
{
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
inputs
()
);
}
}
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
apply_extend_op
(
instruction_ref
ins
)
void
apply_extend_op
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
op
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
inputs
()
);
}
}
void
apply_activation
(
instruction_ref
ins
)
void
apply_activation
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
if
(
op
.
mode
==
"relu"
)
if
(
op
.
mode
==
"relu"
)
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
inputs
()
);
}
}
void
apply_pooling
(
instruction_ref
ins
)
void
apply_pooling
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
if
(
op
.
mode
==
"max"
)
if
(
op
.
mode
==
"max"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
inputs
()
);
else
if
(
op
.
mode
==
"average"
)
else
if
(
op
.
mode
==
"average"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
inputs
()
);
}
}
};
};
...
...
src/targets/gpu/fuse_ops.cpp
View file @
d2af0a8c
...
@@ -28,12 +28,12 @@ void fuse_ops::apply(program& p) const
...
@@ -28,12 +28,12 @@ void fuse_ops::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"gpu::relu"
)
if
(
ins
->
name
()
!=
"gpu::relu"
)
continue
;
continue
;
auto
add_ins
=
ins
->
arguments
.
front
();
auto
add_ins
=
ins
->
inputs
()
.
front
();
if
(
add_ins
->
name
()
!=
"gpu::add"
)
if
(
add_ins
->
name
()
!=
"gpu::add"
)
continue
;
continue
;
auto
args
=
add_ins
->
arguments
;
auto
args
=
add_ins
->
inputs
()
;
// Use the allocation from the relu operator
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
arguments
.
back
();
args
.
back
()
=
ins
->
inputs
()
.
back
();
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
d2af0a8c
...
@@ -372,13 +372,13 @@ struct miopen_apply
...
@@ -372,13 +372,13 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
op
);
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
ins
->
arguments
);
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
ins
->
inputs
()
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
conv
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
workspace
,
output
);
ins
,
conv
,
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
workspace
,
output
);
}
}
instruction_ref
apply_pooling
(
instruction_ref
ins
)
instruction_ref
apply_pooling
(
instruction_ref
ins
)
...
@@ -388,7 +388,7 @@ struct miopen_apply
...
@@ -388,7 +388,7 @@ struct miopen_apply
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_activation
(
instruction_ref
ins
)
instruction_ref
apply_activation
(
instruction_ref
ins
)
...
@@ -399,7 +399,7 @@ struct miopen_apply
...
@@ -399,7 +399,7 @@ struct miopen_apply
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
return
ins
;
return
ins
;
}
}
...
@@ -408,7 +408,7 @@ struct miopen_apply
...
@@ -408,7 +408,7 @@ struct miopen_apply
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
hip_add
{},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
hip_add
{},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_gemm
(
instruction_ref
ins
)
instruction_ref
apply_gemm
(
instruction_ref
ins
)
...
@@ -416,31 +416,31 @@ struct miopen_apply
...
@@ -416,31 +416,31 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_gemm
{
op
},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
shape
old_shape
=
ins
->
arguments
.
at
(
1
)
->
get_shape
();
shape
old_shape
=
ins
->
inputs
()
.
at
(
1
)
->
get_shape
();
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
auto
reshape_op
=
reshape
{
new_shape
};
auto
reshape_op
=
reshape
{
new_shape
};
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
transform
(
ins
->
arguments
.
begin
()
+
1
,
std
::
transform
(
ins
->
inputs
()
.
begin
()
+
1
,
ins
->
arguments
.
end
(),
ins
->
inputs
()
.
end
(),
std
::
back_inserter
(
reshapes
),
std
::
back_inserter
(
reshapes
),
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
return
prog
->
replace_instruction
(
ins
,
return
prog
->
replace_instruction
(
ins
,
miopen_batch_norm_inference
{
op
},
miopen_batch_norm_inference
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
inputs
()
.
at
(
0
),
reshapes
[
0
],
reshapes
[
0
],
reshapes
[
1
],
reshapes
[
1
],
reshapes
[
2
],
reshapes
[
2
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment