Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c174704f
Commit
c174704f
authored
Oct 30, 2018
by
Scott Thornton
Browse files
Formatting
parent
a2092da6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
64 deletions
+67
-64
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+17
-19
src/include/migraph/eliminate_concat.hpp
src/include/migraph/eliminate_concat.hpp
+1
-1
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+1
-4
test/eliminate_concat_test.cpp
test/eliminate_concat_test.cpp
+47
-39
tools/include/concat_opt.hpp
tools/include/concat_opt.hpp
+1
-1
No files found.
src/eliminate_concat.cpp
View file @
c174704f
...
@@ -19,50 +19,48 @@ void eliminate_concat::apply(program& p) const
...
@@ -19,50 +19,48 @@ void eliminate_concat::apply(program& p) const
return
arg
->
name
()
==
"@literal"
;
return
arg
->
name
()
==
"@literal"
;
}))
}))
continue
;
continue
;
// We can only do this optimization when concat axis is either the leftmost
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
// we only need to check the first input
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
if
(
concat_op
.
axis
==
0
||
if
(
concat_op
.
axis
==
0
||
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
concat_op
.
axis
,
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
concat_op
.
axis
,
[](
auto
x
)
{
return
x
==
1
;
}))
[]
(
auto
x
)
{
return
x
==
1
;
}))
{
{
// Last input should be an allocation
// Last input should be an allocation
auto
last
=
ins
->
inputs
().
back
();
auto
last
=
ins
->
inputs
().
back
();
if
(
last
->
name
()
!=
concat_opt
.
allocate
())
continue
;
if
(
last
->
name
()
!=
concat_opt
.
allocate
())
continue
;
// Where are the allocations for the tensors to be concatenated?
// Where are the allocations for the tensors to be concatenated?
std
::
vector
<
instruction_ref
>
allocations
;
std
::
vector
<
instruction_ref
>
allocations
;
for
(
auto
ins2
=
ins
->
inputs
().
begin
();
ins2
!=
ins
->
inputs
().
end
()
-
1
;
ins2
++
)
for
(
auto
ins2
=
ins
->
inputs
().
begin
();
ins2
!=
ins
->
inputs
().
end
()
-
1
;
ins2
++
)
{
{
auto
last2
=
(
*
ins2
)
->
inputs
().
back
();
auto
last2
=
(
*
ins2
)
->
inputs
().
back
();
if
(
last2
->
name
()
==
concat_opt
.
allocate
())
if
(
last2
->
name
()
==
concat_opt
.
allocate
())
{
{
allocations
.
push_back
(
last2
);
allocations
.
push_back
(
last2
);
}
}
}
}
// Need to sort the allocations, so that we know where to
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
// insert the "super"-allocation
std
::
sort
(
allocations
.
begin
(),
allocations
.
end
(),
[
&
]
(
instruction_ref
x
,
instruction_ref
y
)
{
std
::
sort
(
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
allocations
.
begin
(),
allocations
.
end
(),
[
&
](
instruction_ref
x
,
instruction_ref
y
)
{
});
return
std
::
distance
(
p
.
begin
(),
x
)
<
std
::
distance
(
p
.
begin
(),
y
);
});
// Move "super" allocation to the front
// Move "super" allocation to the front
auto
first
=
allocations
.
front
();
auto
first
=
allocations
.
front
();
auto
super
=
p
.
move_instruction
(
last
,
first
);
auto
super
=
p
.
move_instruction
(
last
,
first
);
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
for
(
auto
x
:
allocations
)
for
(
auto
x
:
allocations
)
{
{
migraph
::
op
::
load
op
{
x
->
get_shape
(),
offset
};
migraph
::
op
::
load
op
{
x
->
get_shape
(),
offset
};
p
.
replace_instruction
(
x
,
op
,
{
super
});
p
.
replace_instruction
(
x
,
op
,
{
super
});
offset
+=
x
->
get_shape
().
elements
();
offset
+=
x
->
get_shape
().
elements
();
}
}
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
vector
<
instruction_ref
>
args
=
{
super
};
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
copy
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
()
-
1
,
std
::
back_inserter
(
args
));
std
::
back_inserter
(
args
));
p
.
replace_instruction
(
ins
,
migraph
::
op
::
identity
{},
args
);
p
.
replace_instruction
(
ins
,
migraph
::
op
::
identity
{},
args
);
}
}
}
}
...
...
src/include/migraph/eliminate_concat.hpp
View file @
c174704f
...
@@ -11,7 +11,7 @@ struct program;
...
@@ -11,7 +11,7 @@ struct program;
struct
eliminate_concat
struct
eliminate_concat
{
{
concat_optimization
concat_opt
;
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
...
...
src/include/migraph/operators.hpp
View file @
c174704f
...
@@ -620,10 +620,7 @@ struct unary
...
@@ -620,10 +620,7 @@ struct unary
struct
identity
struct
identity
{
{
std
::
string
name
()
const
{
return
"identity"
;
}
std
::
string
name
()
const
{
return
"identity"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
0
);
}
{
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
...
...
test/eliminate_concat_test.cpp
View file @
c174704f
...
@@ -6,35 +6,27 @@
...
@@ -6,35 +6,27 @@
struct
concat
struct
concat
{
{
concat
(
std
::
size_t
axis
)
concat
(
std
::
size_t
axis
)
{
op
.
axis
=
axis
;
}
{
op
.
axis
=
axis
;
}
migraph
::
op
::
concat
op
;
migraph
::
op
::
concat
op
;
std
::
string
name
()
const
{
return
"eliminate_concat::concat"
;
}
std
::
string
name
()
const
{
return
"eliminate_concat::concat"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
inputs
)
const
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
inputs
)
const
{
{
return
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
);
}
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
ctx
,
compute
(
migraph
::
context
&
ctx
,
const
migraph
::
shape
&
output_shape
,
const
std
::
vector
<
migraph
::
argument
>&
args
)
const
const
migraph
::
shape
&
output_shape
,
const
std
::
vector
<
migraph
::
argument
>&
args
)
const
{
{
return
{
output_shape
};
return
{
output_shape
};
}
}
};
};
struct
concat_test_optimization
struct
concat_test_optimization
{
{
/// A unique name used to identify the concat optimization
/// A unique name used to identify the concat optimization
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"eliminate_concat::concat"
;
}
{
return
"eliminate_concat::concat"
;
}
/// A unique name used to identify the allocate operator
/// A unique name used to identify the allocate operator
std
::
string
allocate
()
const
std
::
string
allocate
()
const
{
return
"allocate"
;
}
{
return
"allocate"
;
}
/// Return the lowered concat operator
/// Return the lowered concat operator
migraph
::
op
::
concat
get_concat
(
const
migraph
::
operation
&
op
)
const
migraph
::
op
::
concat
get_concat
(
const
migraph
::
operation
&
op
)
const
{
{
...
@@ -48,7 +40,8 @@ struct eliminate_concat_target
...
@@ -48,7 +40,8 @@ struct eliminate_concat_target
std
::
string
name
()
const
{
return
"eliminate_target"
;
}
std
::
string
name
()
const
{
return
"eliminate_target"
;
}
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
std
::
vector
<
migraph
::
pass
>
get_passes
(
migraph
::
context
&
)
const
{
{
return
{
migraph
::
eliminate_concat
{
concat_test_optimization
{}},
migraph
::
dead_code_elimination
{}};
return
{
migraph
::
eliminate_concat
{
concat_test_optimization
{}},
migraph
::
dead_code_elimination
{}};
}
}
migraph
::
context
get_context
()
const
{
return
{};
}
migraph
::
context
get_context
()
const
{
return
{};
}
};
};
...
@@ -84,32 +77,39 @@ struct fred_op
...
@@ -84,32 +77,39 @@ struct fred_op
{
{
return
args
.
at
(
0
);
return
args
.
at
(
0
);
}
}
};
};
void
basic
()
void
basic
()
{
{
auto
create_test_program
=
[]()
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}}});
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}}});
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}}});
auto
a3
=
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
return
p
;
};
};
auto
create_control_program
=
[]()
{
auto
create_control_program
=
[]()
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
a1
=
auto
l1
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}},
0
},
{
a1
});
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
10
,
8
,
8
}}});
auto
l1
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
8
,
8
}},
0
},
{
a1
});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
l1
);
auto
p1
=
p
.
add_instruction
(
fred_op
{},
l1
);
auto
l2
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}},
128
},
{
a1
});
auto
l2
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
8
,
8
}},
128
},
{
a1
});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
l2
);
auto
p2
=
p
.
add_instruction
(
fred_op
{},
l2
);
auto
l3
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}},
320
},
{
a1
});
auto
l3
=
p
.
add_instruction
(
migraph
::
op
::
load
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
8
,
8
}},
320
},
{
a1
});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
l3
);
auto
p3
=
p
.
add_instruction
(
fred_op
{},
l3
);
auto
i1
=
p
.
add_instruction
(
migraph
::
op
::
identity
{},
{
a1
,
p1
,
p2
,
p3
});
auto
i1
=
p
.
add_instruction
(
migraph
::
op
::
identity
{},
{
a1
,
p1
,
p2
,
p3
});
return
p
;
return
p
;
...
@@ -126,29 +126,37 @@ void wont_work()
...
@@ -126,29 +126,37 @@ void wont_work()
{
{
auto
create_test_program
=
[]()
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
a3
=
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
return
p
;
};
};
auto
create_control_program
=
[]()
{
auto
create_control_program
=
[]()
{
migraph
::
program
p
;
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
8
,
8
}}});
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
p1
=
p
.
add_instruction
(
fred_op
{},
a1
);
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
a2
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
3
,
8
,
8
}}});
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
p2
=
p
.
add_instruction
(
fred_op
{},
a2
);
auto
a3
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
a3
=
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
5
,
8
,
8
}}});
auto
p3
=
p
.
add_instruction
(
fred_op
{},
a3
);
std
::
size_t
axis
=
1
;
std
::
size_t
axis
=
1
;
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
a4
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
2
,
10
,
8
,
8
}}});
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
auto
p4
=
p
.
add_instruction
(
concat
(
axis
),
p1
,
p2
,
p3
,
a4
);
return
p
;
return
p
;
};
};
auto
p1
=
create_test_program
();
auto
p1
=
create_test_program
();
...
...
tools/include/concat_opt.hpp
View file @
c174704f
...
@@ -18,7 +18,7 @@ struct program;
...
@@ -18,7 +18,7 @@ struct program;
#ifdef DOXYGEN
#ifdef DOXYGEN
/// An interface for applying an optimization for the concat instruction
/// An interface for applying an optimization for the concat instruction
struct
concat_optimization
struct
concat_optimization
{
{
/// A unique name used to identify the concat optimization
/// A unique name used to identify the concat optimization
std
::
string
name
()
const
;
std
::
string
name
()
const
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment