Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7b39fb38
Commit
7b39fb38
authored
Aug 27, 2018
by
mei-ye
Browse files
staging
parent
d877a3fb
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
227 additions
and
185 deletions
+227
-185
src/include/migraph/memory_coloring.hpp
src/include/migraph/memory_coloring.hpp
+0
-2
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+7
-12
src/include/migraph/program.hpp
src/include/migraph/program.hpp
+1
-2
src/opt/common_header.hpp
src/opt/common_header.hpp
+2
-1
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+1
-2
src/opt/memory_coloring_impl.cpp
src/opt/memory_coloring_impl.cpp
+131
-101
src/opt/memory_coloring_impl.hpp
src/opt/memory_coloring_impl.hpp
+75
-48
src/targets/gpu/hip.cpp
src/targets/gpu/hip.cpp
+0
-2
src/targets/gpu/include/migraph/gpu/hip.hpp
src/targets/gpu/include/migraph/gpu/hip.hpp
+7
-11
src/targets/gpu/write_literals.cpp
src/targets/gpu/write_literals.cpp
+3
-4
No files found.
src/include/migraph/memory_coloring.hpp
View file @
7b39fb38
...
@@ -12,8 +12,6 @@ struct memory_coloring
...
@@ -12,8 +12,6 @@ struct memory_coloring
std
::
string
name
()
const
{
return
"memory coloring"
;
}
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/include/migraph/operators.hpp
View file @
7b39fb38
...
@@ -537,12 +537,10 @@ struct div : binary
...
@@ -537,12 +537,10 @@ struct div : binary
struct
get_mem_ptr
struct
get_mem_ptr
{
{
std
::
string
name
()
const
{
return
"get_mem_ptr:"
+
std
::
to_string
(
offset
);
}
std
::
string
name
()
const
{
return
"get_mem_ptr:"
+
std
::
to_string
(
offset
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
1
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
inputs
.
at
(
1
);
return
{
std
::
move
(
output_shape
),
args
.
at
(
0
).
data
()
+
offset
};
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
args
.
at
(
0
).
data
()
+
offset
};
}
}
std
::
size_t
offset
=
0
;
std
::
size_t
offset
=
0
;
};
};
...
@@ -550,11 +548,9 @@ struct get_mem_ptr
...
@@ -550,11 +548,9 @@ struct get_mem_ptr
struct
write_literal
struct
write_literal
{
{
std
::
string
name
()
const
{
return
"write_literal"
;
}
std
::
string
name
()
const
{
return
"write_literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
2
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
{
return
inputs
.
at
(
2
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
assert
(
false
);
assert
(
false
);
}
}
};
};
...
@@ -573,7 +569,6 @@ struct outline
...
@@ -573,7 +569,6 @@ struct outline
return
{
s
,
nullptr
};
return
{
s
,
nullptr
};
}
}
};
};
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/include/migraph/program.hpp
View file @
7b39fb38
...
@@ -102,7 +102,6 @@ struct program
...
@@ -102,7 +102,6 @@ struct program
private:
private:
std
::
unique_ptr
<
program_impl
>
impl
;
std
::
unique_ptr
<
program_impl
>
impl
;
};
};
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/opt/common_header.hpp
View file @
7b39fb38
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <migraph/instruction.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/manage_ptr.hpp>
#include <set>
#include <set>
#include <list>
#include <list>
...
...
src/opt/memory_coloring.cpp
View file @
7b39fb38
...
@@ -2,10 +2,9 @@
...
@@ -2,10 +2,9 @@
#include "memory_coloring_impl.hpp"
#include "memory_coloring_impl.hpp"
namespace
migraph
{
namespace
migraph
{
void
memory_coloring
::
apply
(
program
&
p
)
const
void
memory_coloring
::
apply
(
program
&
p
)
const
{
{
memory_coloring_impl
opt
(
&
p
);
memory_coloring_impl
opt
(
&
p
);
opt
.
run
();
opt
.
run
();
}
}
}
// namespace migraph
}
// namespace migraph
src/opt/memory_coloring_impl.cpp
View file @
7b39fb38
...
@@ -4,29 +4,28 @@ namespace migraph {
...
@@ -4,29 +4,28 @@ namespace migraph {
void
memory_coloring_impl
::
run
()
void
memory_coloring_impl
::
run
()
{
{
build
();
build
();
if
(
num_of_lives
!=
0
)
{
if
(
num_of_lives
!=
0
)
{
DEBUG
(
dump
(
"---Before memory coloring---"
));
DEBUG
(
dump
(
"---Before memory coloring---"
));
DEBUG
(
dump_program
());
DEBUG
(
dump_program
());
DEBUG
(
dump_intervals
());
DEBUG
(
dump_intervals
());
// Coloring
// Coloring
while
(
!
alloc_queue
.
empty
())
{
while
(
!
alloc_queue
.
empty
())
T_live_interval
*
interval
=
alloc_queue
.
top
();
{
interval_ptr
interval
=
alloc_queue
.
top
();
allocate
(
interval
);
allocate
(
interval
);
alloc_queue
.
pop
();
alloc_queue
.
pop
();
}
}
rewrite
();
rewrite
();
DEBUG
(
verify
());
DEBUG
(
verify
());
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
free
(
live_intervals
[
i
]);
}
}
}
}
}
bool
memory_coloring_impl
::
allocate
(
T_live_
interval
*
interval
)
bool
memory_coloring_impl
::
allocate
(
interval
_ptr
interval
)
{
{
shape
s
=
interval
->
result
;
shape
s
=
interval
->
result
;
std
::
size_t
size
=
s
.
bytes
();
std
::
size_t
size
=
s
.
bytes
();
if
(
size
==
0
)
if
(
size
==
0
)
return
false
;
return
false
;
std
::
size_t
element_size
=
size
/
s
.
elements
();
std
::
size_t
element_size
=
size
/
s
.
elements
();
T_live_range
&
segment
=
interval
->
segment
;
T_live_range
&
segment
=
interval
->
segment
;
...
@@ -35,19 +34,25 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
...
@@ -35,19 +34,25 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
std
::
unordered_map
<
long
long
,
T_live_range
*>
offset2Live
;
std
::
unordered_map
<
long
long
,
T_live_range
*>
offset2Live
;
offset2Live
.
clear
();
offset2Live
.
clear
();
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
auto
iter
=
vn_set
.
begin
(),
end
=
vn_set
.
end
();
iter
!=
end
;
++
iter
)
{
for
(
auto
iter
=
vn_set
.
begin
(),
end
=
vn_set
.
end
();
iter
!=
end
;
++
iter
)
{
T_live_range
*
range
=
live_ranges
[
*
iter
];
T_live_range
*
range
=
live_ranges
[
*
iter
];
long
long
offset
=
range
->
offset
;
long
long
offset
=
range
->
offset
;
if
(
offset
!=
InvalidOffset
)
{
if
(
offset
!=
InvalidOffset
)
{
conflict_queue
.
push
(
range
);
conflict_queue
.
push
(
range
);
if
(
offset2Live
.
find
(
offset
)
==
offset2Live
.
end
())
{
if
(
offset2Live
.
find
(
offset
)
==
offset2Live
.
end
())
{
offset2Live
[
offset
]
=
range
;
offset2Live
[
offset
]
=
range
;
}
else
{
}
else
{
T_live_range
*
prev
=
offset2Live
[
offset
];
T_live_range
*
prev
=
offset2Live
[
offset
];
assert
(
prev
->
offset
==
offset
);
assert
(
prev
->
offset
==
offset
);
if
(
prev
->
size
<
range
->
size
)
if
(
prev
->
size
<
range
->
size
)
offset2Live
[
offset
]
=
range
;
offset2Live
[
offset
]
=
range
;
}
}
}
}
...
@@ -55,15 +60,18 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
...
@@ -55,15 +60,18 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
}
}
long
long
offset
=
0
;
long
long
offset
=
0
;
while
(
!
conflict_queue
.
empty
())
{
while
(
!
conflict_queue
.
empty
())
{
T_live_range
*
range
=
conflict_queue
.
top
();
T_live_range
*
range
=
conflict_queue
.
top
();
long
long
cur_offset
=
range
->
offset
;
long
long
cur_offset
=
range
->
offset
;
if
(
offset2Live
[
cur_offset
]
==
range
)
{
if
(
offset2Live
[
cur_offset
]
==
range
)
if
((
cur_offset
>
offset
)
&&
(
cur_offset
-
offset
)
>=
size
)
{
{
if
((
cur_offset
>
offset
)
&&
(
cur_offset
-
offset
)
>=
size
)
{
break
;
break
;
}
}
offset
=
cur_offset
+
range
->
size
;
offset
=
cur_offset
+
range
->
size
;
if
((
offset
%
element_size
)
!=
0
)
if
((
offset
%
element_size
)
!=
0
)
offset
+=
(
element_size
-
(
offset
%
element_size
));
offset
+=
(
element_size
-
(
offset
%
element_size
));
}
}
conflict_queue
.
pop
();
conflict_queue
.
pop
();
...
@@ -77,7 +85,7 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
...
@@ -77,7 +85,7 @@ bool memory_coloring_impl::allocate(T_live_interval* interval)
void
memory_coloring_impl
::
build
()
void
memory_coloring_impl
::
build
()
{
{
int
num_of_instrs
=
p_program
->
get_size
();
int
num_of_instrs
=
p_program
->
get_size
();
if
(
num_of_instrs
==
0
)
if
(
num_of_instrs
==
0
)
return
;
return
;
int
cur_points
=
num_of_instrs
*
2
;
int
cur_points
=
num_of_instrs
*
2
;
instruction_ref
iter
=
std
::
prev
(
p_program
->
end
());
instruction_ref
iter
=
std
::
prev
(
p_program
->
end
());
...
@@ -85,45 +93,56 @@ void memory_coloring_impl::build()
...
@@ -85,45 +93,56 @@ void memory_coloring_impl::build()
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
vector
<
instruction_ref
>
dead_instrs
;
std
::
set
<
int
>
live_set
;
std
::
set
<
int
>
live_set
;
// Build live intervals.
// Build live intervals.
do
{
do
{
const
instruction
*
p_iter
=
&
(
*
iter
);
const
instruction
*
p_iter
=
&
(
*
iter
);
T_live_
interval
*
def_interval
=
nullptr
;
interval
_ptr
def_interval
=
nullptr
;
bool
isDead
=
false
;
bool
isDead
=
false
;
if
(
instr2Live
.
find
(
p_iter
)
!=
instr2Live
.
end
())
{
if
(
instr2Live
.
find
(
p_iter
)
!=
instr2Live
.
end
())
def_interval
=
instr2Live
[
p_iter
];
{
def_interval
=
std
::
move
(
instr2Live
[
p_iter
]);
bool
isLit
=
isLiteral
(
iter
);
bool
isLit
=
isLiteral
(
iter
);
if
(
isAllocate
(
iter
)
||
isLit
)
{
if
(
isAllocate
(
iter
)
||
isLit
)
{
T_live_range
&
range
=
def_interval
->
segment
;
T_live_range
&
range
=
def_interval
->
segment
;
def_interval
->
result
=
iter
->
result
;
def_interval
->
result
=
iter
->
result
;
def_interval
->
isLiteral
=
isLit
;
def_interval
->
isLiteral
=
isLit
;
alloc_queue
.
push
(
def_interval
);
alloc_queue
.
push
(
std
::
move
(
def_interval
)
)
;
range
.
begin
=
cur_points
;
range
.
begin
=
cur_points
;
range
.
size
=
(
iter
->
result
).
bytes
();
range
.
size
=
(
iter
->
result
).
bytes
();
live_set
.
erase
(
range
.
vn
);
live_set
.
erase
(
range
.
vn
);
}
}
}
else
if
(
!
isParam
(
iter
)
&&
!
isOutline
(
iter
)
&&
!
isCheckContext
(
iter
))
{
}
else
if
(
!
isParam
(
iter
)
&&
!
isOutline
(
iter
)
&&
!
isCheckContext
(
iter
))
{
isDead
=
true
;
isDead
=
true
;
}
}
int
tieNdx
=
getInputTieNdx
(
iter
);
int
tieNdx
=
getInputTieNdx
(
iter
);
if
(
!
iter
->
arguments
.
empty
())
{
if
(
!
iter
->
arguments
.
empty
())
{
int
cnt
=
-
1
;
int
cnt
=
-
1
;
for
(
auto
&&
arg
:
iter
->
arguments
)
{
for
(
auto
&&
arg
:
iter
->
arguments
)
{
cnt
++
;
cnt
++
;
if
(
isParam
(
arg
)
||
isOutline
(
arg
))
{
if
(
isParam
(
arg
)
||
isOutline
(
arg
))
if
(
isOutputParam
(
arg
))
{
if
(
isOutputParam
(
arg
))
isDead
=
false
;
isDead
=
false
;
continue
;
continue
;
}
}
const
instruction
*
p_arg
=
&
(
*
arg
);
const
instruction
*
p_arg
=
&
(
*
arg
);
if
(
cnt
==
tieNdx
)
{
if
(
cnt
==
tieNdx
)
{
// input memory is used as this instruction's output.
// input memory is used as this instruction's output.
// def is considered as use. Coalesce the live intervals.
// def is considered as use. Coalesce the live intervals.
def_interval
->
addUse
(
cur_points
);
def_interval
->
addUse
(
cur_points
);
instr2Live
[
p_arg
]
=
def_interval
;
instr2Live
[
p_arg
]
=
def_interval
;
}
else
if
(
instr2Live
.
find
(
p_arg
)
==
instr2Live
.
end
())
{
}
else
if
(
instr2Live
.
find
(
p_arg
)
==
instr2Live
.
end
())
{
// First time see a use, create a live interval.
// First time see a use, create a live interval.
int
id
=
num_of_lives
++
;
int
id
=
num_of_lives
++
;
T_live_
interval
*
interval
=
new
live_interval
();
interval
_ptr
interval
(
new
live_interval
()
)
;
interval
->
id
=
id
;
interval
->
id
=
id
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
end
=
cur_points
;
interval
->
segment
.
vn
=
++
max_value_number
;
interval
->
segment
.
vn
=
++
max_value_number
;
...
@@ -131,20 +150,22 @@ void memory_coloring_impl::build()
...
@@ -131,20 +150,22 @@ void memory_coloring_impl::build()
instr2Live
[
p_arg
]
=
interval
;
instr2Live
[
p_arg
]
=
interval
;
addConflicts
(
live_set
,
max_value_number
);
addConflicts
(
live_set
,
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_set
.
insert
(
max_value_number
);
live_intervals
[
id
]
=
interval
;
live_intervals
[
id
]
=
std
::
move
(
interval
)
;
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
live_ranges
[
max_value_number
]
=
&
(
interval
->
segment
);
}
else
{
}
T_live_interval
*
interval
=
instr2Live
[
p_arg
];
else
{
interval_ptr
interval
=
instr2Live
[
p_arg
];
interval
->
addUse
(
cur_points
);
interval
->
addUse
(
cur_points
);
DEBUG
(
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
()));
DEBUG
(
assert
(
live_set
.
find
(
interval
->
id
)
!=
live_set
.
end
()));
}
}
}
}
}
}
if
(
isDead
)
if
(
isDead
)
dead_instrs
.
push_back
(
iter
);
dead_instrs
.
push_back
(
iter
);
cur_points
-=
2
;
cur_points
-=
2
;
iter
=
std
::
prev
(
iter
);
iter
=
std
::
prev
(
iter
);
}
while
(
iter
!=
begin
);
}
while
(
iter
!=
begin
);
}
}
void
memory_coloring_impl
::
rewrite
()
void
memory_coloring_impl
::
rewrite
()
...
@@ -152,22 +173,29 @@ void memory_coloring_impl::rewrite()
...
@@ -152,22 +173,29 @@ void memory_coloring_impl::rewrite()
instruction_ref
end
=
p_program
->
end
();
instruction_ref
end
=
p_program
->
end
();
instruction_ref
scratch_param
=
end
;
instruction_ref
scratch_param
=
end
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
dims
.
push_back
(
required_bytes
/
sizeof
(
float
));
dims
.
push_back
(
required_bytes
/
sizeof
(
float
));
shape
s
=
{
shape
::
float_type
,
dims
};
shape
s
=
{
shape
::
float_type
,
dims
};
scratch_param
=
p_program
->
add_parameter
(
"scratch"
,
s
);
scratch_param
=
p_program
->
add_parameter
(
"scratch"
,
s
);
for
(
auto
ins
:
iterator_for
(
*
p_program
))
{
for
(
auto
ins
:
iterator_for
(
*
p_program
))
{
const
instruction
*
p_iter
=
&
(
*
ins
);
const
instruction
*
p_iter
=
&
(
*
ins
);
if
(
instr2Live
.
find
(
p_iter
)
!=
instr2Live
.
end
())
{
if
(
instr2Live
.
find
(
p_iter
)
!=
instr2Live
.
end
())
T_live_interval
*
interval
=
instr2Live
[
p_iter
];
{
if
(
interval
->
get_offset
()
==
InvalidOffset
)
{
interval_ptr
interval
=
instr2Live
[
p_iter
];
if
(
interval
->
get_offset
()
==
InvalidOffset
)
{
DEBUG
(
assert
((
interval
->
get_begin
()
==
InvalidOffset
)
||
DEBUG
(
assert
((
interval
->
get_begin
()
==
InvalidOffset
)
||
interval
->
result
.
bytes
()
==
0
));
interval
->
result
.
bytes
()
==
0
));
continue
;
continue
;
}
}
std
::
size_t
offset
=
interval
->
get_offset
();
std
::
size_t
offset
=
interval
->
get_offset
();
if
(
isAllocate
(
ins
))
{
if
(
isAllocate
(
ins
))
p_program
->
replace_instruction
(
ins
,
get_mem_ptr
{
offset
},
scratch_param
,
ins
->
arguments
.
at
(
0
));
{
}
else
if
(
isLiteral
(
ins
))
{
p_program
->
replace_instruction
(
ins
,
get_mem_ptr
{
offset
},
scratch_param
,
ins
->
arguments
.
at
(
0
));
}
else
if
(
isLiteral
(
ins
))
{
auto
pre
=
p_program
->
add_literal
(
ins
->
lit
);
auto
pre
=
p_program
->
add_literal
(
ins
->
lit
);
auto
index
=
p_program
->
add_literal
(
offset
);
auto
index
=
p_program
->
add_literal
(
offset
);
p_program
->
replace_instruction
(
ins
,
write_literal
{},
scratch_param
,
index
,
pre
);
p_program
->
replace_instruction
(
ins
,
write_literal
{},
scratch_param
,
index
,
pre
);
...
@@ -179,31 +207,28 @@ void memory_coloring_impl::rewrite()
...
@@ -179,31 +207,28 @@ void memory_coloring_impl::rewrite()
}
}
#ifdef DEBUG_OPT
#ifdef DEBUG_OPT
void
memory_coloring_impl
::
dump
(
std
::
string
str
)
void
memory_coloring_impl
::
dump
(
std
::
string
str
)
{
std
::
cout
<<
str
<<
std
::
endl
;
}
{
std
::
cout
<<
str
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_program
()
void
memory_coloring_impl
::
dump_program
()
{
std
::
cout
<<
*
p_program
<<
std
::
endl
;
}
{
std
::
cout
<<
*
p_program
<<
std
::
endl
;
}
void
memory_coloring_impl
::
dump_intervals
()
void
memory_coloring_impl
::
dump_intervals
()
{
{
if
(
num_of_lives
>
0
)
{
if
(
num_of_lives
>
0
)
{
std
::
cout
<<
"---live intervals ---"
<<
std
::
endl
;
std
::
cout
<<
"---live intervals ---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
T_live_interval
*
interval
=
live_intervals
[
i
];
{
interval_ptr
interval
=
live_intervals
[
i
];
interval
->
dump
();
interval
->
dump
();
}
}
std
::
cout
<<
"---conflict table---"
<<
std
::
endl
;
std
::
cout
<<
"---conflict table---"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<=
max_value_number
;
++
i
)
{
for
(
int
i
=
0
;
i
<=
max_value_number
;
++
i
)
{
std
::
cout
<<
" segment:"
<<
i
;
std
::
cout
<<
" segment:"
<<
i
;
std
::
cout
<<
" =>"
;
std
::
cout
<<
" =>"
;
std
::
set
<
int
>&
table
=
conflict_table
[
i
];
std
::
set
<
int
>&
table
=
conflict_table
[
i
];
for
(
auto
iter
=
table
.
begin
(),
end
=
table
.
end
();
iter
!=
end
;
++
iter
)
{
for
(
auto
iter
=
table
.
begin
(),
end
=
table
.
end
();
iter
!=
end
;
++
iter
)
{
std
::
cout
<<
(
*
iter
)
<<
","
;
std
::
cout
<<
(
*
iter
)
<<
","
;
}
}
}
}
...
@@ -213,20 +238,24 @@ void memory_coloring_impl::dump_intervals()
...
@@ -213,20 +238,24 @@ void memory_coloring_impl::dump_intervals()
void
memory_coloring_impl
::
verify
()
void
memory_coloring_impl
::
verify
()
{
{
if
(
num_of_lives
>
0
)
{
if
(
num_of_lives
>
0
)
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
{
T_live_interval
*
interval
=
live_intervals
[
i
];
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
{
interval_ptr
interval
=
live_intervals
[
i
];
T_live_range
&
segment
=
interval
->
segment
;
T_live_range
&
segment
=
interval
->
segment
;
if
(
segment
.
offset
==
InvalidOffset
)
if
(
segment
.
offset
==
InvalidOffset
)
continue
;
continue
;
int
vn
=
segment
.
vn
;
int
vn
=
segment
.
vn
;
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
if
(
conflict_table
.
find
(
vn
)
!=
conflict_table
.
end
())
{
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
std
::
set
<
int
>&
vn_set
=
conflict_table
[
vn
];
for
(
auto
iter
=
vn_set
.
begin
(),
end
=
vn_set
.
end
();
iter
!=
end
;
++
iter
)
{
for
(
auto
iter
=
vn_set
.
begin
(),
end
=
vn_set
.
end
();
iter
!=
end
;
++
iter
)
{
T_live_range
*
range
=
live_ranges
[
*
iter
];
T_live_range
*
range
=
live_ranges
[
*
iter
];
if
(
range
->
offset
==
InvalidOffset
)
if
(
range
->
offset
==
InvalidOffset
)
continue
;
continue
;
if
(
!
isDisjoin
(
*
range
,
segment
))
if
(
!
isDisjoin
(
*
range
,
segment
))
assert
(
false
);
assert
(
false
);
}
}
}
}
...
@@ -240,7 +269,8 @@ void live_range::dump()
...
@@ -240,7 +269,8 @@ void live_range::dump()
{
{
std
::
cout
<<
" segment:"
<<
vn
;
std
::
cout
<<
" segment:"
<<
vn
;
std
::
cout
<<
" ["
<<
GET_INS_ENUM
(
begin
)
<<
", "
<<
GET_INS_ENUM
(
end
)
<<
"]"
;
std
::
cout
<<
" ["
<<
GET_INS_ENUM
(
begin
)
<<
", "
<<
GET_INS_ENUM
(
end
)
<<
"]"
;
if
(
offset
!=
InvalidOffset
)
{
if
(
offset
!=
InvalidOffset
)
{
std
::
cout
<<
" mem:"
;
std
::
cout
<<
" mem:"
;
std
::
cout
<<
" ["
<<
offset
<<
","
<<
offset
+
size
-
1
<<
"]"
;
std
::
cout
<<
" ["
<<
offset
<<
","
<<
offset
+
size
-
1
<<
"]"
;
}
}
...
@@ -252,17 +282,17 @@ void live_interval::dump()
...
@@ -252,17 +282,17 @@ void live_interval::dump()
std
::
cout
<<
"id:"
<<
id
;
std
::
cout
<<
"id:"
<<
id
;
segment
.
dump
();
segment
.
dump
();
std
::
cout
<<
" uses:"
;
std
::
cout
<<
" uses:"
;
for
(
auto
iter
=
use_points
.
begin
(),
end
=
use_points
.
end
();
iter
!=
end
;
++
iter
)
{
for
(
auto
iter
=
use_points
.
begin
(),
end
=
use_points
.
end
();
iter
!=
end
;
++
iter
)
{
int
&
use
=
*
iter
;
int
&
use
=
*
iter
;
std
::
cout
<<
" "
<<
GET_INS_ENUM
(
use
)
<<
","
;
std
::
cout
<<
" "
<<
GET_INS_ENUM
(
use
)
<<
","
;
}
}
if
(
isLiteral
)
if
(
isLiteral
)
std
::
cout
<<
" literal"
;
std
::
cout
<<
" literal"
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
" "
<<
result
;
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
}
}
#endif
#endif
}
// namespace migraph
}
// namespace migraph
src/opt/memory_coloring_impl.hpp
View file @
7b39fb38
...
@@ -6,7 +6,9 @@ namespace migraph {
...
@@ -6,7 +6,9 @@ namespace migraph {
#define InvalidOffset -1
#define InvalidOffset -1
typedef
struct
live_range
{
typedef
struct
live_range
{
int
begin
;
// begin point in the instruction stream.
int
begin
;
// begin point in the instruction stream.
int
end
;
// end point in the instruction stream.
int
end
;
// end point in the instruction stream.
long
long
offset
;
// offset to base pointer of allocated memory trunk.
long
long
offset
;
// offset to base pointer of allocated memory trunk.
...
@@ -17,12 +19,15 @@ typedef struct live_range {
...
@@ -17,12 +19,15 @@ typedef struct live_range {
#endif
#endif
}
T_live_range
;
}
T_live_range
;
typedef
struct
live_interval
{
typedef
struct
live_interval
explicit
live_interval
()
{
init
();
}
{
live_interval
()
{
init
();
}
void
init
()
{
~
live_interval
()
{}
id
=
-
1
;
isLiteral
=
false
;
void
init
()
segment
=
{
-
1
,
-
1
,
InvalidOffset
,
-
1
,
0
};
{
id
=
-
1
;
isLiteral
=
false
;
segment
=
{
-
1
,
-
1
,
InvalidOffset
,
-
1
,
0
};
}
}
void
addUse
(
int
use
)
{
use_points
.
push_front
(
use
);
}
void
addUse
(
int
use
)
{
use_points
.
push_front
(
use
);
}
int
get_begin
()
const
{
return
segment
.
begin
;
}
int
get_begin
()
const
{
return
segment
.
begin
;
}
...
@@ -41,9 +46,21 @@ typedef struct live_interval {
...
@@ -41,9 +46,21 @@ typedef struct live_interval {
}
T_live_interval
;
}
T_live_interval
;
struct
memory_coloring_impl
{
// #define unique_interval_ptr std::unique_ptr<T_live_interval>
explicit
memory_coloring_impl
(
program
*
p
)
:
p_program
(
p
)
#define interval_ptr T_live_interval*
struct
memory_coloring_impl
{
memory_coloring_impl
(){}
memory_coloring_impl
(
program
*
p
)
:
p_program
(
p
)
{
{
init
();
}
~
memory_coloring_impl
()
{
for
(
int
i
=
0
;
i
<
num_of_lives
;
++
i
)
free
(
live_intervals
[
i
]);
}
void
init
()
{
instr2Live
.
clear
();
instr2Live
.
clear
();
live_intervals
.
clear
();
live_intervals
.
clear
();
live_ranges
.
clear
();
live_ranges
.
clear
();
...
@@ -52,10 +69,11 @@ struct memory_coloring_impl {
...
@@ -52,10 +69,11 @@ struct memory_coloring_impl {
max_value_number
=
-
1
;
max_value_number
=
-
1
;
required_bytes
=
0
;
required_bytes
=
0
;
}
}
bool
allocate
(
T_live_
interval
*
);
bool
allocate
(
interval
_ptr
);
void
addConflicts
(
std
::
set
<
int
>&
live_set
,
int
val
)
void
addConflicts
(
std
::
set
<
int
>&
live_set
,
int
val
)
{
{
for
(
auto
iter
=
live_set
.
begin
(),
end
=
live_set
.
end
();
iter
!=
end
;
++
iter
)
{
for
(
auto
iter
=
live_set
.
begin
(),
end
=
live_set
.
end
();
iter
!=
end
;
++
iter
)
{
conflict_table
[
*
iter
].
insert
(
val
);
conflict_table
[
*
iter
].
insert
(
val
);
conflict_table
[
val
].
insert
(
*
iter
);
conflict_table
[
val
].
insert
(
*
iter
);
}
}
...
@@ -63,6 +81,7 @@ struct memory_coloring_impl {
...
@@ -63,6 +81,7 @@ struct memory_coloring_impl {
void
build
();
void
build
();
void
run
();
void
run
();
void
rewrite
();
void
rewrite
();
private:
private:
bool
isParam
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"@param"
;
}
bool
isParam
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"@param"
;
}
bool
isOutputParam
(
const
instruction_ref
ins
)
bool
isOutputParam
(
const
instruction_ref
ins
)
...
@@ -78,19 +97,22 @@ struct memory_coloring_impl {
...
@@ -78,19 +97,22 @@ struct memory_coloring_impl {
bool
isLiteral
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"@literal"
;
}
bool
isLiteral
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"@literal"
;
}
bool
isCheckContext
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"check_context"
;
}
bool
isCheckContext
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"check_context"
;
}
bool
isTranspose
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"transpose"
;
}
bool
isTranspose
(
const
instruction_ref
ins
)
{
return
ins
->
op
.
name
()
==
"transpose"
;
}
int
getInputTieNdx
(
const
instruction_ref
ins
)
{
int
getInputTieNdx
(
const
instruction_ref
ins
)
if
(
isTranspose
(
ins
))
{
if
(
isTranspose
(
ins
))
return
0
;
return
0
;
int
cnt
=
-
1
;
int
cnt
=
-
1
;
int
last_allocate
=
-
1
;
int
last_allocate
=
-
1
;
for
(
auto
&&
arg
:
ins
->
arguments
)
{
for
(
auto
&&
arg
:
ins
->
arguments
)
{
cnt
++
;
cnt
++
;
if
(
isAllocate
(
arg
))
if
(
isAllocate
(
arg
))
last_allocate
=
cnt
;
last_allocate
=
cnt
;
}
}
return
last_allocate
;
return
last_allocate
;
}
}
bool
isDisjoin
(
T_live_range
&
range1
,
T_live_range
&
range2
)
{
bool
isDisjoin
(
T_live_range
&
range1
,
T_live_range
&
range2
)
{
long
long
end1
=
range1
.
offset
+
range1
.
size
-
1
;
long
long
end1
=
range1
.
offset
+
range1
.
size
-
1
;
long
long
end2
=
range2
.
offset
+
range2
.
size
-
1
;
long
long
end2
=
range2
.
offset
+
range2
.
size
-
1
;
return
((
end1
<
range2
.
offset
)
||
(
end2
<
range1
.
offset
));
return
((
end1
<
range2
.
offset
)
||
(
end2
<
range1
.
offset
));
...
@@ -102,39 +124,44 @@ struct memory_coloring_impl {
...
@@ -102,39 +124,44 @@ struct memory_coloring_impl {
void
dump_intervals
();
void
dump_intervals
();
void
verify
();
void
verify
();
#endif
#endif
struct
ordering
{
struct
ordering
bool
operator
()
(
const
T_live_interval
*
I1
,
const
T_live_interval
*
I2
)
const
{
bool
operator
()(
const
interval_ptr
I1
,
const
interval_ptr
I2
)
const
{
{
int
len1
=
I1
->
get_end
()
-
I1
->
get_begin
();
int
len1
=
I1
->
get_end
()
-
I1
->
get_begin
();
int
len2
=
I2
->
get_end
()
-
I2
->
get_begin
();
int
len2
=
I2
->
get_end
()
-
I2
->
get_begin
();
if
(
len1
!=
len2
)
{
if
(
len1
!=
len2
)
{
return
(
len1
<
len2
)
?
true
:
false
;
return
(
len1
<
len2
)
?
true
:
false
;
}
else
if
(
I1
->
result
.
bytes
()
!=
I2
->
result
.
bytes
())
{
}
else
if
(
I1
->
result
.
bytes
()
!=
I2
->
result
.
bytes
())
{
return
(
I1
->
result
.
bytes
()
<
I2
->
result
.
bytes
())
?
true
:
false
;
return
(
I1
->
result
.
bytes
()
<
I2
->
result
.
bytes
())
?
true
:
false
;
}
else
{
}
else
{
return
I1
->
id
>
I2
->
id
;
return
I1
->
id
>
I2
->
id
;
}
}
}
}
bool
operator
()
(
const
T_live_range
*
I1
,
const
T_live_range
*
I2
)
const
bool
operator
()(
const
T_live_range
*
I1
,
const
T_live_range
*
I2
)
const
{
{
return
(
I1
->
offset
>
I2
->
offset
);
return
(
I1
->
offset
>
I2
->
offset
);
}
}
};
};
program
*
p_program
;
program
*
p_program
;
std
::
unordered_map
<
const
instruction
*
,
T_live_
interval
*
>
instr2Live
;
std
::
unordered_map
<
const
instruction
*
,
interval
_ptr
>
instr2Live
;
// Map live interval Id to live interval.
// Map live interval Id to live interval.
std
::
unordered_map
<
int
,
T_live_
interval
*
>
live_intervals
;
std
::
unordered_map
<
int
,
interval
_ptr
>
live_intervals
;
// Map live range value number to live range.
// Map live range value number to live range.
std
::
unordered_map
<
int
,
T_live_range
*>
live_ranges
;
std
::
unordered_map
<
int
,
T_live_range
*>
live_ranges
;
// Map live range value number to a set of conflicting live ranges' value numbers.
// Map live range value number to a set of conflicting live ranges' value numbers.
std
::
unordered_map
<
int
,
std
::
set
<
int
>>
conflict_table
;
std
::
unordered_map
<
int
,
std
::
set
<
int
>>
conflict_table
;
// Priority queue for coloring.
// Priority queue for coloring.
std
::
priority_queue
<
T_live_
interval
*
,
std
::
vector
<
T_live_
interval
*
>
,
ordering
>
alloc_queue
;
std
::
priority_queue
<
interval
_ptr
,
std
::
vector
<
interval
_ptr
>
,
ordering
>
alloc_queue
;
int
num_of_lives
;
int
num_of_lives
;
int
max_value_number
;
int
max_value_number
;
long
long
required_bytes
;
long
long
required_bytes
;
};
};
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/targets/gpu/hip.cpp
View file @
7b39fb38
...
@@ -94,7 +94,5 @@ void copy_to_gpu(char* dst, const char* src, std::size_t size)
...
@@ -94,7 +94,5 @@ void copy_to_gpu(char* dst, const char* src, std::size_t size)
{
{
hipMemcpy
(
dst
,
src
,
size
,
hipMemcpyHostToDevice
);
hipMemcpy
(
dst
,
src
,
size
,
hipMemcpyHostToDevice
);
}
}
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
src/targets/gpu/include/migraph/gpu/hip.hpp
View file @
7b39fb38
...
@@ -85,12 +85,10 @@ struct hip_write
...
@@ -85,12 +85,10 @@ struct hip_write
struct
hip_memcpy
struct
hip_memcpy
{
{
std
::
string
name
()
const
{
return
"hip_memcpy"
;
}
std
::
string
name
()
const
{
return
"hip_memcpy"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
inputs
.
at
(
2
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
return
inputs
.
at
(
2
);
std
::
size_t
*
p_data
=
reinterpret_cast
<
std
::
size_t
*>
(
args
.
at
(
1
).
data
());
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
std
::
size_t
*
p_data
=
reinterpret_cast
<
std
::
size_t
*>
(
args
.
at
(
1
).
data
());
char
*
dst
=
args
.
at
(
0
).
data
()
+
p_data
[
0
];
char
*
dst
=
args
.
at
(
0
).
data
()
+
p_data
[
0
];
const
char
*
src
=
args
.
at
(
2
).
data
();
const
char
*
src
=
args
.
at
(
2
).
data
();
std
::
size_t
size
=
args
.
at
(
2
).
get_shape
().
bytes
();
std
::
size_t
size
=
args
.
at
(
2
).
get_shape
().
bytes
();
...
@@ -98,9 +96,7 @@ struct hip_memcpy
...
@@ -98,9 +96,7 @@ struct hip_memcpy
return
{
output_shape
,
dst
};
return
{
output_shape
,
dst
};
}
}
};
};
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
#endif
#endif
src/targets/gpu/write_literals.cpp
View file @
7b39fb38
...
@@ -37,13 +37,12 @@ void write_literals::apply(program& p) const
...
@@ -37,13 +37,12 @@ void write_literals::apply(program& p) const
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
p.replace_instruction(ins, hip_load_literal{a.get_shape(), n});
}
}
#else
#else
if
(
ins
->
op
.
name
()
==
"write_literal"
)
{
if
(
ins
->
op
.
name
()
==
"write_literal"
)
{
p
.
replace_instruction
(
ins
,
hip_memcpy
{},
ins
->
arguments
);
p
.
replace_instruction
(
ins
,
hip_memcpy
{},
ins
->
arguments
);
}
}
#endif
#endif
}
}
}
}
}
// namespace gpu
}
// namespace gpu
}
// namespace migraph
}
// namespace migraph
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment