Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d2af0a8c
"torchvision/datasets/vscode:/vscode.git/clone" did not exist on "6662b30a14940db530cdfbff1c2169b6e8fd297d"
Commit
d2af0a8c
authored
Sep 15, 2018
by
Paul
Browse files
Rename arguments to use inputs function
parent
872ff12a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
45 additions
and
45 deletions
+45
-45
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+3
-3
src/fwd_conv_batchnorm_rewrite.cpp
src/fwd_conv_batchnorm_rewrite.cpp
+9
-9
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+3
-3
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+1
-1
src/program.cpp
src/program.cpp
+6
-6
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-3
src/targets/cpu/cpu_lowering.cpp
src/targets/cpu/cpu_lowering.cpp
+5
-5
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+3
-3
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+11
-11
No files found.
src/dead_code_elimination.cpp
View file @
d2af0a8c
...
@@ -26,7 +26,7 @@ void dead_code_elimination::apply(program& p) const
...
@@ -26,7 +26,7 @@ void dead_code_elimination::apply(program& p) const
assert
(
p
.
has_instruction
(
leaf
));
assert
(
p
.
has_instruction
(
leaf
));
if
(
leaf
->
output
.
empty
())
if
(
leaf
->
output
.
empty
())
{
{
auto
args
=
leaf
->
arguments
;
auto
args
=
leaf
->
inputs
()
;
leaf
->
clear_arguments
();
leaf
->
clear_arguments
();
p
.
move_instruction
(
leaf
,
p
.
end
());
p
.
move_instruction
(
leaf
,
p
.
end
());
for
(
auto
arg
:
args
)
for
(
auto
arg
:
args
)
...
...
src/eliminate_contiguous.cpp
View file @
d2af0a8c
...
@@ -27,15 +27,15 @@ void eliminate_contiguous::apply(program& p) const
...
@@ -27,15 +27,15 @@ void eliminate_contiguous::apply(program& p) const
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
arguments
;
auto
args
=
ins
->
inputs
()
;
for
(
auto
arg
:
ins
->
arguments
)
for
(
auto
arg
:
ins
->
inputs
()
)
{
{
// TODO: Pass in names for the operator in the constructor instead
// TODO: Pass in names for the operator in the constructor instead
// of using ends_with
// of using ends_with
if
(
ends_with
(
arg
->
name
(),
"contiguous"
))
if
(
ends_with
(
arg
->
name
(),
"contiguous"
))
{
{
auto
new_args
=
args
;
auto
new_args
=
args
;
auto
prev
=
arg
->
arguments
.
front
();
auto
prev
=
arg
->
inputs
()
.
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
->
op
,
new_args
))
if
(
try_compute_shape
(
ins
->
op
,
new_args
))
{
{
...
...
src/fwd_conv_batchnorm_rewrite.cpp
View file @
d2af0a8c
...
@@ -12,26 +12,26 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -12,26 +12,26 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
if
(
not
std
::
all_of
(
ins
->
arguments
.
begin
()
+
1
,
ins
->
arguments
.
end
(),
[](
auto
arg
)
{
if
(
not
std
::
all_of
(
ins
->
inputs
()
.
begin
()
+
1
,
ins
->
inputs
()
.
end
(),
[](
auto
arg
)
{
return
arg
->
name
()
==
"@literal"
;
return
arg
->
name
()
==
"@literal"
;
}))
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
arguments
[
0
];
auto
conv_ins
=
ins
->
inputs
()
[
0
];
if
(
conv_ins
->
name
()
!=
"convolution"
)
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
continue
;
if
(
conv_ins
->
arguments
[
1
]
->
name
()
!=
"@literal"
)
if
(
conv_ins
->
inputs
()
[
1
]
->
name
()
!=
"@literal"
)
continue
;
continue
;
// Get scale, bias, mean, variance from instruction_ref
// Get scale, bias, mean, variance from instruction_ref
const
auto
&
gamma
=
ins
->
arguments
[
1
]
->
get_literal
();
const
auto
&
gamma
=
ins
->
inputs
()
[
1
]
->
get_literal
();
const
auto
&
bias
=
ins
->
arguments
[
2
]
->
get_literal
();
const
auto
&
bias
=
ins
->
inputs
()
[
2
]
->
get_literal
();
const
auto
&
mean
=
ins
->
arguments
[
3
]
->
get_literal
();
const
auto
&
mean
=
ins
->
inputs
()
[
3
]
->
get_literal
();
const
auto
&
variance
=
ins
->
arguments
[
4
]
->
get_literal
();
const
auto
&
variance
=
ins
->
inputs
()
[
4
]
->
get_literal
();
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
bn_op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution weights
// Get convolution weights
const
auto
&
weights
=
conv_ins
->
arguments
[
1
]
->
get_literal
();
const
auto
&
weights
=
conv_ins
->
inputs
()
[
1
]
->
get_literal
();
// Get convolution op
// Get convolution op
auto
conv_op
=
conv_ins
->
op
;
auto
conv_op
=
conv_ins
->
op
;
auto
weights_lens
=
weights
.
get_shape
().
lens
();
auto
weights_lens
=
weights
.
get_shape
().
lens
();
...
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -58,7 +58,7 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
// Replace convolution instruction with updated weights
// Replace convolution instruction with updated weights
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
arguments
[
0
],
l_weights
});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
inputs
()
[
0
],
l_weights
});
auto
b
=
p
.
insert_instruction
(
ins
,
broadcast
{
1
},
c
,
l_bias
);
auto
b
=
p
.
insert_instruction
(
ins
,
broadcast
{
1
},
c
,
l_bias
);
p
.
replace_instruction
(
ins
,
add
{},
{
c
,
b
});
p
.
replace_instruction
(
ins
,
add
{},
{
c
,
b
});
}
}
...
...
src/include/migraph/instruction.hpp
View file @
d2af0a8c
...
@@ -87,8 +87,8 @@ struct instruction
...
@@ -87,8 +87,8 @@ struct instruction
}
}
return
result
==
computed
&&
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
return
std
::
find
(
i
->
inputs
()
.
begin
(),
i
->
inputs
()
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
i
->
inputs
()
.
end
();
});
});
}
}
...
@@ -156,7 +156,7 @@ struct instruction
...
@@ -156,7 +156,7 @@ struct instruction
inline
void
backreference
(
instruction_ref
ref
)
inline
void
backreference
(
instruction_ref
ref
)
{
{
for
(
auto
&&
arg
:
ref
->
arguments
)
for
(
auto
&&
arg
:
ref
->
inputs
()
)
arg
->
add_output
(
ref
);
arg
->
add_output
(
ref
);
}
}
...
...
src/onnx/verify_onnx.cpp
View file @
d2af0a8c
...
@@ -68,7 +68,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
...
@@ -68,7 +68,7 @@ void verify_instructions(const migraph::program& prog, double tolerance = 80)
auto
create_program
=
[
&
]
{
auto
create_program
=
[
&
]
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
for
(
auto
&&
arg
:
ins
.
arguments
)
for
(
auto
&&
arg
:
ins
.
inputs
()
)
{
{
if
(
arg
->
name
()
==
"@literal"
)
if
(
arg
->
name
()
==
"@literal"
)
inputs
.
push_back
(
p
.
add_literal
(
arg
->
lit
));
inputs
.
push_back
(
p
.
add_literal
(
arg
->
lit
));
...
...
src/program.cpp
View file @
d2af0a8c
...
@@ -48,10 +48,10 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -48,10 +48,10 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os
<<
"{"
<<
ins
->
lit
<<
"}"
;
os
<<
"{"
<<
ins
->
lit
<<
"}"
;
}
}
if
(
!
ins
->
arguments
.
empty
())
if
(
!
ins
->
inputs
()
.
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
arguments
)
for
(
auto
&&
arg
:
ins
->
inputs
()
)
{
{
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
os
<<
delim
<<
names
.
at
(
arg
);
os
<<
delim
<<
names
.
at
(
arg
);
...
@@ -93,7 +93,7 @@ instruction_ref program::insert_instruction(instruction_ref ins,
...
@@ -93,7 +93,7 @@ instruction_ref program::insert_instruction(instruction_ref ins,
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
backreference
(
result
);
backreference
(
result
);
// assert(result->
arguments
== args);
// assert(result->
inputs()
== args);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -300,9 +300,9 @@ argument generic_eval(const program& p,
...
@@ -300,9 +300,9 @@ argument generic_eval(const program& p,
}
}
else
else
{
{
values
.
resize
(
ins
->
arguments
.
size
());
values
.
resize
(
ins
->
inputs
()
.
size
());
std
::
transform
(
ins
->
arguments
.
begin
(),
std
::
transform
(
ins
->
inputs
()
.
begin
(),
ins
->
arguments
.
end
(),
ins
->
inputs
()
.
end
(),
values
.
begin
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
[
&
](
instruction_ref
i
)
{
assert
(
results
.
find
(
i
)
!=
results
.
end
());
assert
(
results
.
find
(
i
)
!=
results
.
end
());
...
...
src/simplify_reshapes.cpp
View file @
d2af0a8c
...
@@ -35,9 +35,9 @@ void simplify_reshapes::apply(program& p) const
...
@@ -35,9 +35,9 @@ void simplify_reshapes::apply(program& p) const
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
{
{
assert
(
!
reshapes
.
back
()
->
arguments
.
empty
());
assert
(
!
reshapes
.
back
()
->
inputs
()
.
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
arguments
.
front
()));
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
()
.
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
arguments
.
front
());
reshapes
.
push_back
(
reshapes
.
back
()
->
inputs
()
.
front
());
}
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
...
...
src/targets/cpu/cpu_lowering.cpp
View file @
d2af0a8c
...
@@ -597,30 +597,30 @@ struct cpu_apply
...
@@ -597,30 +597,30 @@ struct cpu_apply
template
<
class
T
>
template
<
class
T
>
void
apply_simple_op
(
instruction_ref
ins
)
void
apply_simple_op
(
instruction_ref
ins
)
{
{
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
inputs
()
);
}
}
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
apply_extend_op
(
instruction_ref
ins
)
void
apply_extend_op
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
op
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
inputs
()
);
}
}
void
apply_activation
(
instruction_ref
ins
)
void
apply_activation
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
if
(
op
.
mode
==
"relu"
)
if
(
op
.
mode
==
"relu"
)
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
inputs
()
);
}
}
void
apply_pooling
(
instruction_ref
ins
)
void
apply_pooling
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
if
(
op
.
mode
==
"max"
)
if
(
op
.
mode
==
"max"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
inputs
()
);
else
if
(
op
.
mode
==
"average"
)
else
if
(
op
.
mode
==
"average"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
inputs
()
);
}
}
};
};
...
...
src/targets/gpu/fuse_ops.cpp
View file @
d2af0a8c
...
@@ -28,12 +28,12 @@ void fuse_ops::apply(program& p) const
...
@@ -28,12 +28,12 @@ void fuse_ops::apply(program& p) const
{
{
if
(
ins
->
name
()
!=
"gpu::relu"
)
if
(
ins
->
name
()
!=
"gpu::relu"
)
continue
;
continue
;
auto
add_ins
=
ins
->
arguments
.
front
();
auto
add_ins
=
ins
->
inputs
()
.
front
();
if
(
add_ins
->
name
()
!=
"gpu::add"
)
if
(
add_ins
->
name
()
!=
"gpu::add"
)
continue
;
continue
;
auto
args
=
add_ins
->
arguments
;
auto
args
=
add_ins
->
inputs
()
;
// Use the allocation from the relu operator
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
arguments
.
back
();
args
.
back
()
=
ins
->
inputs
()
.
back
();
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
d2af0a8c
...
@@ -372,13 +372,13 @@ struct miopen_apply
...
@@ -372,13 +372,13 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
op
);
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
ins
->
arguments
);
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
ins
->
inputs
()
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
conv
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
workspace
,
output
);
ins
,
conv
,
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
workspace
,
output
);
}
}
instruction_ref
apply_pooling
(
instruction_ref
ins
)
instruction_ref
apply_pooling
(
instruction_ref
ins
)
...
@@ -388,7 +388,7 @@ struct miopen_apply
...
@@ -388,7 +388,7 @@ struct miopen_apply
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_activation
(
instruction_ref
ins
)
instruction_ref
apply_activation
(
instruction_ref
ins
)
...
@@ -399,7 +399,7 @@ struct miopen_apply
...
@@ -399,7 +399,7 @@ struct miopen_apply
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
return
ins
;
return
ins
;
}
}
...
@@ -408,7 +408,7 @@ struct miopen_apply
...
@@ -408,7 +408,7 @@ struct miopen_apply
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
hip_add
{},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
hip_add
{},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_gemm
(
instruction_ref
ins
)
instruction_ref
apply_gemm
(
instruction_ref
ins
)
...
@@ -416,31 +416,31 @@ struct miopen_apply
...
@@ -416,31 +416,31 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_gemm
{
op
},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
shape
old_shape
=
ins
->
arguments
.
at
(
1
)
->
get_shape
();
shape
old_shape
=
ins
->
inputs
()
.
at
(
1
)
->
get_shape
();
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
auto
reshape_op
=
reshape
{
new_shape
};
auto
reshape_op
=
reshape
{
new_shape
};
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
transform
(
ins
->
arguments
.
begin
()
+
1
,
std
::
transform
(
ins
->
inputs
()
.
begin
()
+
1
,
ins
->
arguments
.
end
(),
ins
->
inputs
()
.
end
(),
std
::
back_inserter
(
reshapes
),
std
::
back_inserter
(
reshapes
),
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
return
prog
->
replace_instruction
(
ins
,
return
prog
->
replace_instruction
(
ins
,
miopen_batch_norm_inference
{
op
},
miopen_batch_norm_inference
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
inputs
()
.
at
(
0
),
reshapes
[
0
],
reshapes
[
0
],
reshapes
[
1
],
reshapes
[
1
],
reshapes
[
2
],
reshapes
[
2
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment