Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
22cee7ff
Commit
22cee7ff
authored
May 03, 2023
by
Paul
Browse files
Format
parent
d0dbaf41
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
77 additions
and
64 deletions
+77
-64
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+1
-4
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+12
-12
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+5
-1
src/instruction.cpp
src/instruction.cpp
+3
-1
src/module.cpp
src/module.cpp
+8
-4
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+48
-42
No files found.
src/include/migraphx/common_dims.hpp
View file @
22cee7ff
...
@@ -12,10 +12,7 @@ struct common_dims
...
@@ -12,10 +12,7 @@ struct common_dims
{
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
const
std
::
vector
<
std
::
size_t
>&
dims2
);
bool
empty
()
const
bool
empty
()
const
{
return
dims
.
empty
();
}
{
return
dims
.
empty
();
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
...
...
src/include/migraphx/matcher.hpp
View file @
22cee7ff
...
@@ -203,18 +203,18 @@ struct basic_matcher
...
@@ -203,18 +203,18 @@ struct basic_matcher
{
{
// Copy m because we cant capture `this` by value
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
auto
mm
=
m
;
return
make_basic_fun_matcher
(
[
=
](
matcher_context
&
ctx
,
return
make_basic_fun_matcher
(
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
[
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
if
(
result
)
{
{
bool
matches
=
bool
matches
=
+
fold
(
+
fold
(
[
&
](
auto
x
,
auto
y
)
{
return
x
and
ctx
.
matched
(
y
,
result
);
})(
true
,
ms
...);
[
&
](
auto
x
,
auto
y
)
{
return
x
and
ctx
.
matched
(
y
,
result
);
})(
true
,
ms
...);
if
(
matches
)
if
(
matches
)
return
result
;
return
result
;
}
}
return
nullopt
;
return
nullopt
;
});
});
}
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
auto
bind
(
std
::
string
name
)
const
{
return
bind_match
(
m
,
std
::
move
(
name
));
}
...
...
src/include/migraphx/module.hpp
View file @
22cee7ff
...
@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
...
@@ -54,7 +54,11 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/
*/
struct
module
struct
module
{
{
using
inserter
=
std
::
function
<
instruction_ref
(
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
>
;
using
inserter
=
std
::
function
<
instruction_ref
(
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
>
;
module
(
const
std
::
string
&
name
=
""
);
module
(
const
std
::
string
&
name
=
""
);
// move constructor
// move constructor
...
...
src/instruction.cpp
100755 → 100644
View file @
22cee7ff
...
@@ -328,7 +328,9 @@ bool instruction::can_eval() const
...
@@ -328,7 +328,9 @@ bool instruction::can_eval() const
}
}
else
if
(
is_context_free
(
op
))
else
if
(
is_context_free
(
op
))
{
{
assert
(
std
::
none_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[
&
](
instruction_ref
arg
)
{
return
std
::
addressof
(
*
arg
)
==
this
;
}));
assert
(
std
::
none_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[
&
](
instruction_ref
arg
)
{
return
std
::
addressof
(
*
arg
)
==
this
;
}));
return
std
::
all_of
(
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
}
}
...
...
src/module.cpp
View file @
22cee7ff
...
@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins,
...
@@ -430,11 +430,13 @@ module::insert_instructions(instruction_ref ins,
const
std
::
vector
<
instruction_ref
>&
instructions
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
),
default_module_inserter
());
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
),
default_module_inserter
());
}
}
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
module
::
insert_instructions
(
module
::
inserter
insert
,
instruction_ref
ins
,
module
::
insert_instructions
(
module
::
inserter
insert
,
instruction_ref
ins
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
...
@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins,
...
@@ -446,7 +448,8 @@ module::insert_instructions(instruction_ref ins,
const_module_ref
m
,
const_module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
),
default_module_inserter
());
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
),
default_module_inserter
());
}
}
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
...
@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins,
...
@@ -456,7 +459,8 @@ module::insert_instructions(instruction_ref ins,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
auto
r
=
range
(
start
,
last
);
auto
r
=
range
(
start
,
last
);
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
),
default_module_inserter
());
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
),
default_module_inserter
());
}
}
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
...
...
src/simplify_reshapes.cpp
View file @
22cee7ff
...
@@ -917,10 +917,11 @@ struct find_broadcast_reshaper
...
@@ -917,10 +917,11 @@ struct find_broadcast_reshaper
struct
find_poinwise_reduce_reshape
struct
find_poinwise_reduce_reshape
{
{
template
<
class
...
Ms
>
template
<
class
...
Ms
>
static
auto
match_reshaper
(
Ms
...
ms
)
static
auto
match_reshaper
(
Ms
...
ms
)
{
{
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
))(
ms
...)));
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
))(
ms
...)));
}
}
auto
matcher
()
const
auto
matcher
()
const
{
{
...
@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape
...
@@ -935,10 +936,7 @@ struct find_poinwise_reduce_reshape
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
());
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
());
}
}
static
bool
is_broadcast
(
instruction_ref
ins
)
static
bool
is_broadcast
(
instruction_ref
ins
)
{
return
is_broadcast
(
ins
->
get_operator
());
}
{
return
is_broadcast
(
ins
->
get_operator
());
}
static
bool
is_pointwise
(
instruction_ref
ins
)
static
bool
is_pointwise
(
instruction_ref
ins
)
{
{
...
@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape
...
@@ -946,10 +944,7 @@ struct find_poinwise_reduce_reshape
return
a
.
get
(
"pointwise"
,
false
);
return
a
.
get
(
"pointwise"
,
false
);
}
}
static
bool
is_reduce
(
instruction_ref
ins
)
static
bool
is_reduce
(
instruction_ref
ins
)
{
return
is_reduce
(
ins
->
get_operator
());
}
{
return
is_reduce
(
ins
->
get_operator
());
}
static
bool
is_reduce
(
const
operation
&
op
)
static
bool
is_reduce
(
const
operation
&
op
)
{
{
...
@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape
...
@@ -963,23 +958,25 @@ struct find_poinwise_reduce_reshape
return
a
.
get
(
"pointwise"
,
false
)
or
a
.
get
(
"reduce"
,
false
);
return
a
.
get
(
"pointwise"
,
false
)
or
a
.
get
(
"reduce"
,
false
);
}
}
static
std
::
vector
<
instruction_ref
>
topo_sort
(
instruction_ref
entry
,
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
static
std
::
vector
<
instruction_ref
>
topo_sort
(
instruction_ref
entry
,
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
{
{
std
::
vector
<
instruction_ref
>
instructions
;
std
::
vector
<
instruction_ref
>
instructions
;
bool
has_entry
=
contains
(
inss
,
entry
);
bool
has_entry
=
contains
(
inss
,
entry
);
fix
([
&
](
auto
self
,
instruction_ref
ins
)
{
fix
([
&
](
auto
self
,
instruction_ref
ins
)
{
if
(
ins
!=
entry
or
has_entry
)
if
(
ins
!=
entry
or
has_entry
)
instructions
.
push_back
(
ins
);
instructions
.
push_back
(
ins
);
for
(
auto
input
:
ins
->
inputs
())
for
(
auto
input
:
ins
->
inputs
())
{
{
if
(
not
contains
(
inss
,
input
))
if
(
not
contains
(
inss
,
input
))
aux
.
insert
(
input
);
aux
.
insert
(
input
);
}
}
for
(
auto
output
:
ins
->
outputs
())
for
(
auto
output
:
ins
->
outputs
())
{
{
if
(
contains
(
instructions
,
output
))
if
(
contains
(
instructions
,
output
))
continue
;
continue
;
if
(
not
contains
(
inss
,
output
))
if
(
not
contains
(
inss
,
output
))
continue
;
continue
;
self
(
output
);
self
(
output
);
}
}
...
@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape
...
@@ -988,23 +985,24 @@ struct find_poinwise_reduce_reshape
return
instructions
;
return
instructions
;
}
}
static
std
::
vector
<
instruction_ref
>
topo_sort
(
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
static
std
::
vector
<
instruction_ref
>
topo_sort
(
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
{
{
std
::
vector
<
instruction_ref
>
instructions
;
std
::
vector
<
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
visited
;
std
::
unordered_set
<
instruction_ref
>
visited
;
for
(
auto
ins
:
inss
)
for
(
auto
ins
:
inss
)
{
{
fix
([
&
](
auto
self
,
instruction_ref
child
)
{
fix
([
&
](
auto
self
,
instruction_ref
child
)
{
if
(
contains
(
visited
,
child
))
if
(
contains
(
visited
,
child
))
return
;
return
;
for
(
auto
output
:
child
->
outputs
())
for
(
auto
output
:
child
->
outputs
())
{
{
if
(
not
contains
(
inss
,
output
))
if
(
not
contains
(
inss
,
output
))
continue
;
continue
;
self
(
output
);
self
(
output
);
}
}
visited
.
insert
(
child
);
visited
.
insert
(
child
);
for
(
auto
input
:
child
->
inputs
())
for
(
auto
input
:
child
->
inputs
())
{
{
if
(
not
contains
(
inss
,
input
))
if
(
not
contains
(
inss
,
input
))
aux
.
insert
(
input
);
aux
.
insert
(
input
);
...
@@ -1025,11 +1023,11 @@ struct find_poinwise_reduce_reshape
...
@@ -1025,11 +1023,11 @@ struct find_poinwise_reduce_reshape
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
nelements
=
x_ins
->
get_shape
().
elements
();
auto
nelements
=
x_ins
->
get_shape
().
elements
();
auto
dims1
=
x_ins
->
get_shape
().
lens
();
auto
dims1
=
x_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
auto
cd
=
common_dims
::
compute
(
dims1
,
dims2
);
auto
cd
=
common_dims
::
compute
(
dims1
,
dims2
);
if
(
cd
.
empty
())
if
(
cd
.
empty
())
return
;
return
;
// m.debug_print();
// m.debug_print();
...
@@ -1064,28 +1062,30 @@ struct find_poinwise_reduce_reshape
...
@@ -1064,28 +1062,30 @@ struct find_poinwise_reduce_reshape
// Collect from output
// Collect from output
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
// if(contains(inss, out))
// if(contains(inss, out))
// return;
// return;
// std::cout << "Visit: ";
// std::cout << "Visit: ";
// m.debug_print(out);
// m.debug_print(out);
// m.debug_print(out->inputs());
// m.debug_print(out->inputs());
auto
outputs
=
out
->
outputs
();
auto
outputs
=
out
->
outputs
();
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
instruction_ref
i
)
{
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
instruction_ref
i
)
{
return
std
::
distance
(
reshape_ins
,
i
);
return
std
::
distance
(
reshape_ins
,
i
);
}));
}));
// m.debug_print(outputs);
// m.debug_print(outputs);
for
(
auto
output
:
outputs
)
for
(
auto
output
:
outputs
)
{
{
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
return
input
->
can_eval
()
or
reshape_ins
==
input
or
contains
(
output_inss
,
input
);
// or dom.strictly_dominate(reshape_ins, input);
return
input
->
can_eval
()
or
reshape_ins
==
input
or
contains
(
output_inss
,
input
);
// or dom.strictly_dominate(reshape_ins, input);
}))
}))
continue
;
continue
;
if
(
not
is_pointwise_or_reduce
(
output
)
and
not
is_broadcast
(
output
))
if
(
not
is_pointwise_or_reduce
(
output
)
and
not
is_broadcast
(
output
))
continue
;
continue
;
if
(
is_reduce
(
output
))
if
(
is_reduce
(
output
))
{
{
auto
op_axes
=
output
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
auto
op_axes
=
output
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
if
(
axes
.
empty
())
if
(
axes
.
empty
())
axes
=
op_axes
;
axes
=
op_axes
;
if
(
axes
!=
op_axes
)
if
(
axes
!=
op_axes
)
return
;
return
;
...
@@ -1096,18 +1096,19 @@ struct find_poinwise_reduce_reshape
...
@@ -1096,18 +1096,19 @@ struct find_poinwise_reduce_reshape
})(
reshape_ins
);
})(
reshape_ins
);
std
::
vector
<
int64_t
>
common_axes
;
std
::
vector
<
int64_t
>
common_axes
;
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
common_axes
.
insert
(
common_axes
.
end
(),
cd
.
axes_map2
[
axis
].
begin
(),
cd
.
axes_map2
[
axis
].
end
());
common_axes
.
insert
(
common_axes
.
end
(),
cd
.
axes_map2
[
axis
].
begin
(),
cd
.
axes_map2
[
axis
].
end
());
}
}
auto
common_rdims
=
cd
.
dims
;
auto
common_rdims
=
cd
.
dims
;
for
(
auto
axis
:
common_axes
)
for
(
auto
axis
:
common_axes
)
{
{
common_rdims
[
axis
]
=
1
;
common_rdims
[
axis
]
=
1
;
}
}
// Topological sort
// Topological sort
std
::
unordered_set
<
instruction_ref
>
aux
;
std
::
unordered_set
<
instruction_ref
>
aux
;
auto
input_instructions
=
topo_sort
(
input_inss
,
aux
);
auto
input_instructions
=
topo_sort
(
input_inss
,
aux
);
auto
output_instructions
=
topo_sort
(
output_inss
,
aux
);
auto
output_instructions
=
topo_sort
(
output_inss
,
aux
);
// std::cout << "output_inss:\n";
// std::cout << "output_inss:\n";
// m.debug_print({output_inss.begin(), output_inss.end()});
// m.debug_print({output_inss.begin(), output_inss.end()});
...
@@ -1116,23 +1117,28 @@ struct find_poinwise_reduce_reshape
...
@@ -1116,23 +1117,28 @@ struct find_poinwise_reduce_reshape
// std::cout << "aux:\n";
// std::cout << "aux:\n";
// m.debug_print({aux.begin(), aux.end()});
// m.debug_print({aux.begin(), aux.end()});
auto
last
=
output_instructions
.
back
();
auto
last
=
output_instructions
.
back
();
auto
insert_reshape
=
[
&
](
instruction_ref
input
)
{
auto
insert_reshape
=
[
&
](
instruction_ref
input
)
{
auto
use_rdims
=
input
->
get_shape
().
elements
()
<
nelements
;
auto
use_rdims
=
input
->
get_shape
().
elements
()
<
nelements
;
auto
c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
input
);
auto
c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
use_rdims
?
common_rdims
:
cd
.
dims
}}),
c
);
return
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
use_rdims
?
common_rdims
:
cd
.
dims
}}),
c
);
};
};
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// map_ins[entry] = insert_reshape(entry);
// map_ins[entry] = insert_reshape(entry);
for
(
auto
i
:
aux
)
for
(
auto
i
:
aux
)
{
{
map_ins
[
i
]
=
insert_reshape
(
i
);
map_ins
[
i
]
=
insert_reshape
(
i
);
}
}
auto
inserter
=
[
&
](
module
&
mm
,
instruction_ref
i
,
operation
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
{
auto
inserter
=
[
&
](
module
&
mm
,
if
(
is_reduce
(
op
))
instruction_ref
i
,
operation
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
{
if
(
is_reduce
(
op
))
op
.
from_value
({{
"axes"
,
common_axes
}});
op
.
from_value
({{
"axes"
,
common_axes
}});
if
(
is_broadcast
(
op
))
if
(
is_broadcast
(
op
))
op
.
from_value
({{
"out_lens"
,
cd
.
dims
}});
op
.
from_value
({{
"out_lens"
,
cd
.
dims
}});
// std::cout << op << std::endl;
// std::cout << op << std::endl;
// m.debug_print(args);
// m.debug_print(args);
...
@@ -1141,7 +1147,7 @@ struct find_poinwise_reduce_reshape
...
@@ -1141,7 +1147,7 @@ struct find_poinwise_reduce_reshape
auto
new_x_ins
=
m
.
insert_instructions
(
inserter
,
last
,
input_instructions
,
map_ins
).
front
();
auto
new_x_ins
=
m
.
insert_instructions
(
inserter
,
last
,
input_instructions
,
map_ins
).
front
();
map_ins
[
reshape_ins
]
=
new_x_ins
;
map_ins
[
reshape_ins
]
=
new_x_ins
;
auto
new_last
=
m
.
insert_instructions
(
inserter
,
last
,
output_instructions
,
map_ins
).
front
();
auto
new_last
=
m
.
insert_instructions
(
inserter
,
last
,
output_instructions
,
map_ins
).
front
();
auto
new_c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
new_last
);
auto
new_c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
new_last
);
auto
new_reshape
=
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
dims2
}}),
new_c
);
auto
new_reshape
=
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
dims2
}}),
new_c
);
m
.
debug_print
();
m
.
debug_print
();
m
.
debug_print
(
last
);
m
.
debug_print
(
last
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment