Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11e155c2
Commit
11e155c2
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
8a9c5bce
aa7ff911
Changes
397
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
200 additions
and
68 deletions
+200
-68
src/include/migraphx/check_context.hpp
src/include/migraphx/check_context.hpp
+1
-1
src/include/migraphx/compile_src.hpp
src/include/migraphx/compile_src.hpp
+2
-0
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+15
-11
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+50
-11
src/include/migraphx/cpp_generator.hpp
src/include/migraphx/cpp_generator.hpp
+2
-0
src/include/migraphx/eliminate_allocation.hpp
src/include/migraphx/eliminate_allocation.hpp
+1
-1
src/include/migraphx/eliminate_common_subexpression.hpp
src/include/migraphx/eliminate_common_subexpression.hpp
+1
-1
src/include/migraphx/eliminate_concat.hpp
src/include/migraphx/eliminate_concat.hpp
+1
-1
src/include/migraphx/eliminate_contiguous.hpp
src/include/migraphx/eliminate_contiguous.hpp
+1
-1
src/include/migraphx/eliminate_identity.hpp
src/include/migraphx/eliminate_identity.hpp
+1
-1
src/include/migraphx/filesystem.hpp
src/include/migraphx/filesystem.hpp
+4
-1
src/include/migraphx/gemm.hpp
src/include/migraphx/gemm.hpp
+4
-2
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+4
-4
src/include/migraphx/json.hpp
src/include/migraphx/json.hpp
+1
-0
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/marker.hpp
src/include/migraphx/marker.hpp
+17
-12
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+81
-17
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/op/as_shape.hpp
src/include/migraphx/op/as_shape.hpp
+0
-1
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+0
-1
No files found.
src/include/migraphx/check_context.hpp
View file @
11e155c2
...
...
@@ -33,7 +33,7 @@ struct check_context
};
std
::
string
name
()
const
{
return
"check_context"
;
}
void
apply
(
module
&
p
)
const
{
p
.
insert_instruction
(
p
.
begin
(),
op
{});
}
void
apply
(
module
&
m
)
const
{
m
.
insert_instruction
(
m
.
begin
(),
op
{});
}
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/compile_src.hpp
View file @
11e155c2
...
...
@@ -23,6 +23,8 @@ struct src_compiler
std
::
string
compiler
=
"c++"
;
std
::
string
flags
=
""
;
std
::
string
output
=
""
;
std
::
string
launcher
=
""
;
std
::
string
out_ext
=
".o"
;
std
::
function
<
fs
::
path
(
fs
::
path
)
>
process
=
nullptr
;
std
::
vector
<
char
>
compile
(
const
std
::
vector
<
src_file
>&
srcs
)
const
;
};
...
...
src/include/migraphx/concat_opt.hpp
View file @
11e155c2
...
...
@@ -30,17 +30,20 @@ struct concat_optimization
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
concat_optimization
{
//
std
::
string
name
()
const
;
//
std
::
string
allocate
()
const
;
//
op
::
concat
get_concat
(
const
operation
&
op
)
const
;
};
#else
struct
concat_optimization
{
...
...
@@ -244,6 +247,7 @@ inline const ValueType& any_cast(const concat_optimization& x)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
#endif
...
...
src/include/migraphx/context.hpp
View file @
11e155c2
...
...
@@ -9,6 +9,7 @@
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -37,17 +38,28 @@ void from_value_context(T&, const value&)
{
}
/*
* Type-erased interface for:
*
* struct context
* {
* value to_value() const;
* void from_value(const value& v) ;
* void finish() const;
* };
*
*/
template
<
class
T
>
any_ptr
get_queue_context
(
T
&
)
{
return
{};
}
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
context
{
// (optional)
value
to_value
()
const
;
// (optional)
void
from_value
(
const
value
&
v
);
// (optional)
any_ptr
get_queue
();
//
void
finish
()
const
;
};
#else
struct
context
{
...
...
@@ -124,6 +136,12 @@ struct context
(
*
this
).
private_detail_te_get_handle
().
from_value
(
v
);
}
any_ptr
get_queue
()
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
}
void
finish
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -145,6 +163,7 @@ struct context
virtual
value
to_value
()
const
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
void
finish
()
const
=
0
;
};
...
...
@@ -176,6 +195,19 @@ struct context
from_value_context
(
private_detail_te_self
,
v
);
}
template
<
class
T
>
static
auto
private_detail_te_default_get_queue
(
char
,
T
&&
private_detail_te_self
)
->
decltype
(
private_detail_te_self
.
get_queue
())
{
return
private_detail_te_self
.
get_queue
();
}
template
<
class
T
>
static
any_ptr
private_detail_te_default_get_queue
(
float
,
T
&&
private_detail_te_self
)
{
return
get_queue_context
(
private_detail_te_self
);
}
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
...
...
@@ -216,6 +248,12 @@ struct context
private_detail_te_default_from_value
(
char
(
0
),
private_detail_te_value
,
v
);
}
any_ptr
get_queue
()
override
{
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
...
...
@@ -282,6 +320,7 @@ inline const ValueType& any_cast(const context& x)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
{
v
=
ctx
.
to_value
();
}
inline
void
migraphx_from_value
(
const
value
&
v
,
context
&
ctx
)
{
ctx
.
from_value
(
v
);
}
...
...
src/include/migraphx/cpp_generator.hpp
View file @
11e155c2
...
...
@@ -68,6 +68,8 @@ struct cpp_generator
void
fmap
(
const
std
::
function
<
std
::
string
(
std
::
string
)
>&
f
);
void
fresult
(
const
std
::
function
<
std
::
string
(
shape
)
>&
f
);
void
add_point_op
(
const
std
::
string
&
op_name
,
const
std
::
string
&
code
);
std
::
string
generate_point_op
(
const
operation
&
op
,
const
std
::
vector
<
std
::
string
>&
args
);
...
...
src/include/migraphx/eliminate_allocation.hpp
View file @
11e155c2
...
...
@@ -19,7 +19,7 @@ struct eliminate_allocation
std
::
string
allocation_op
{};
std
::
size_t
alignment
=
32
;
std
::
string
name
()
const
{
return
"eliminate_allocation"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_common_subexpression.hpp
View file @
11e155c2
...
...
@@ -16,7 +16,7 @@ struct module;
struct
eliminate_common_subexpression
{
std
::
string
name
()
const
{
return
"eliminate_common_subexpression"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_concat.hpp
View file @
11e155c2
...
...
@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization
concat_opt
;
std
::
string
name
()
const
{
return
"eliminate_concat"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_contiguous.hpp
View file @
11e155c2
...
...
@@ -17,7 +17,7 @@ struct eliminate_contiguous
{
std
::
string
op_name
;
std
::
string
name
()
const
{
return
"eliminate_contiguous"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/eliminate_identity.hpp
View file @
11e155c2
...
...
@@ -18,7 +18,7 @@ struct module;
struct
eliminate_identity
{
std
::
string
name
()
const
{
return
"eliminate_identity"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/filesystem.hpp
View file @
11e155c2
...
...
@@ -3,7 +3,10 @@
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1
#else
...
...
src/include/migraphx/gemm.hpp
View file @
11e155c2
...
...
@@ -3,7 +3,7 @@
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/
shape_for_each
.hpp>
#include <migraphx/
par_for
.hpp>
#include <migraphx/tensor_view.hpp>
namespace
migraphx
{
...
...
@@ -20,8 +20,10 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_0
]
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
cmat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_1
]);
auto
cs
=
cmat
.
get_shape
();
shape_for_each
(
cmat
.
get_shape
(),
[
&
](
const
auto
&
c_idx
)
{
par_for
(
cs
.
elements
(),
[
&
](
auto
i
)
{
auto
c_idx
=
cs
.
multi
(
i
);
auto
a_idx
=
c_idx
;
auto
b_idx
=
c_idx
;
double
s
=
0.0
;
...
...
src/include/migraphx/generate.hpp
View file @
11e155c2
...
...
@@ -88,16 +88,16 @@ struct xorshift_generator
template
<
class
T
>
auto
generate_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
seed
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
xorshf96_generator
<
T
>
{
seed
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
xorshf96_generator
<
T
>
{
seed
});
return
result
;
}
template
<
class
T
>
auto
fill_tensor_data
(
const
migraphx
::
shape
&
s
,
unsigned
long
value
=
0
)
{
auto
result
=
make_shared_array
<
T
>
(
s
.
element
s
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
s
(),
[
=
]
{
return
value
;
});
auto
result
=
make_shared_array
<
T
>
(
s
.
element
_space
());
std
::
generate
(
result
.
get
(),
result
.
get
()
+
s
.
element
_space
(),
[
=
]
{
return
value
;
});
return
result
;
}
...
...
src/include/migraphx/json.hpp
View file @
11e155c2
...
...
@@ -8,6 +8,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
string
to_pretty_json_string
(
const
value
&
val
,
std
::
size_t
indent
=
4
);
std
::
string
to_json_string
(
const
value
&
val
);
value
from_json_string
(
const
std
::
string
&
str
);
value
from_json_string
(
const
char
*
str
,
std
::
size_t
size
);
...
...
src/include/migraphx/make_op.hpp
View file @
11e155c2
...
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/marker.hpp
View file @
11e155c2
...
...
@@ -20,18 +20,22 @@ inline namespace MIGRAPHX_INLINE_NS {
#else
/*
* Type-erased interface for:
*
* struct marker
* {
* void mark_start(instruction_ref ins_ref) ;
* void mark_start(const program& prog) ;
* void mark_stop(instruction_ref ins) ;
* void mark_stop(const program& prog) ;
* };
*
*/
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct
marker
{
//
void
mark_start
(
instruction_ref
ins_ref
);
//
void
mark_start
(
const
program
&
prog
);
//
void
mark_stop
(
instruction_ref
ins
);
//
void
mark_stop
(
const
program
&
prog
);
};
#else
struct
marker
{
...
...
@@ -243,6 +247,7 @@ inline const ValueType& any_cast(const marker& x)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
#endif
...
...
src/include/migraphx/matcher.hpp
100755 → 100644
View file @
11e155c2
...
...
@@ -101,17 +101,17 @@ template <class M>
auto
bind_match
(
M
m
,
std
::
string
name
)
{
return
make_function_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
)
{
if
(
not
ctx
.
has_instruction
(
ins
))
return
nullopt
;
ctx
.
instructions
[
name
]
=
ins
;
}
return
result
;
});
[
=
,
name
=
std
::
move
(
name
)](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
)
{
if
(
not
ctx
.
has_instruction
(
ins
))
return
nullopt
;
ctx
.
instructions
[
name
]
=
ins
;
}
return
result
;
});
}
/// Convert a matcher to a bindable matcher
...
...
@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
struct
basic_matcher
...
...
@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
{
...
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
};
...
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
else
{
...
...
@@ -533,10 +579,22 @@ auto skip_output(Ms... ms)
});
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
[
=
,
s
=
std
::
move
(
s
)](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name_contains
(
const
std
::
string
&
name
)
...
...
@@ -547,7 +605,7 @@ inline auto name_contains(const std::string& name)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
...
...
@@ -696,10 +754,16 @@ auto skip_broadcasts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_converts
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
return
skip_broadcasts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts
_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
auto
l
=
ins
->
get_literal
();
...
...
src/include/migraphx/memory_coloring.hpp
View file @
11e155c2
...
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/op/as_shape.hpp
View file @
11e155c2
...
...
@@ -36,7 +36,6 @@ struct as_shape
{
return
args
.
front
().
reshape
(
output_shape
);
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/broadcast.hpp
View file @
11e155c2
...
...
@@ -67,7 +67,6 @@ struct broadcast
{
return
args
[
0
].
reshape
(
output_shape
);
}
lifetime
get_lifetime
()
const
{
return
lifetime
::
borrow
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
Prev
1
2
3
4
5
6
7
8
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment