Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fb75dfaf
Commit
fb75dfaf
authored
Aug 14, 2018
by
Paul
Browse files
Only use no-cache on jenkins
parents
e596eec2
f0604d78
Changes
122
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
843 additions
and
31 deletions
+843
-31
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+22
-0
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+40
-0
src/generate.cpp
src/generate.cpp
+27
-0
src/include/migraph/argument.hpp
src/include/migraph/argument.hpp
+6
-12
src/include/migraph/auto_any_cast.hpp
src/include/migraph/auto_any_cast.hpp
+37
-0
src/include/migraph/auto_contiguous.hpp
src/include/migraph/auto_contiguous.hpp
+19
-0
src/include/migraph/builtin.hpp
src/include/migraph/builtin.hpp
+44
-0
src/include/migraph/check_context.hpp
src/include/migraph/check_context.hpp
+30
-0
src/include/migraph/check_shapes.hpp
src/include/migraph/check_shapes.hpp
+125
-0
src/include/migraph/context.hpp
src/include/migraph/context.hpp
+196
-0
src/include/migraph/dead_code_elimination.hpp
src/include/migraph/dead_code_elimination.hpp
+19
-0
src/include/migraph/dfor.hpp
src/include/migraph/dfor.hpp
+4
-4
src/include/migraph/erase.hpp
src/include/migraph/erase.hpp
+6
-4
src/include/migraph/errors.hpp
src/include/migraph/errors.hpp
+7
-7
src/include/migraph/fallthrough.hpp
src/include/migraph/fallthrough.hpp
+14
-0
src/include/migraph/float_equal.hpp
src/include/migraph/float_equal.hpp
+4
-4
src/include/migraph/functional.hpp
src/include/migraph/functional.hpp
+39
-0
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+26
-0
src/include/migraph/instruction.hpp
src/include/migraph/instruction.hpp
+165
-0
src/include/migraph/instruction_ref.hpp
src/include/migraph/instruction_ref.hpp
+13
-0
No files found.
src/auto_contiguous.cpp
0 → 100644
View file @
fb75dfaf
#include <migraph/auto_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
namespace
migraph
{
void
auto_contiguous
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
shape
s
=
ins
->
result
;
if
(
not
s
.
standard
())
{
auto
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
contiguous
{},
ins
);
p
.
replace_instruction
(
ins
,
c
);
}
}
}
}
// namespace migraph
src/dead_code_elimination.cpp
0 → 100644
View file @
fb75dfaf
#include <migraph/dead_code_elimination.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/functional.hpp>
namespace
migraph
{
void
dead_code_elimination
::
apply
(
program
&
p
)
const
{
auto
last
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
// Skip the first instruction, since we always process the previous
// instruction
if
(
ins
==
p
.
begin
())
continue
;
const
auto
i
=
std
::
prev
(
ins
);
// Skip instruction with empty shape as output
if
(
i
->
result
.
elements
()
==
0
)
continue
;
// Skip the last instruction
if
(
i
==
last
)
break
;
fix
([
&
](
auto
self
,
auto
leaf
)
{
assert
(
p
.
has_instruction
(
leaf
));
if
(
leaf
->
output
.
empty
())
{
auto
args
=
leaf
->
arguments
;
leaf
->
clear_arguments
();
p
.
move_instruction
(
leaf
,
p
.
end
());
for
(
auto
arg
:
args
)
self
(
arg
);
}
})(
i
);
}
p
.
remove_instructions
(
std
::
next
(
last
),
p
.
end
());
}
}
// namespace migraph
src/generate.cpp
0 → 100644
View file @
fb75dfaf
#include <migraph/generate.hpp>
namespace
migraph
{
argument
generate_argument
(
shape
s
,
std
::
mt19937
::
result_type
seed
)
{
argument
result
;
s
.
visit_type
([
&
](
auto
as
)
{
using
type
=
typename
decltype
(
as
)
::
type
;
auto
v
=
generate_tensor_data
<
type
>
(
s
,
seed
);
result
=
{
s
,
[
v
]()
mutable
{
return
reinterpret_cast
<
char
*>
(
v
.
data
());
}};
});
return
result
;
}
literal
generate_literal
(
shape
s
,
std
::
mt19937
::
result_type
seed
)
{
literal
result
;
s
.
visit_type
([
&
](
auto
as
)
{
using
type
=
typename
decltype
(
as
)
::
type
;
auto
v
=
generate_tensor_data
<
type
>
(
s
,
seed
);
result
=
{
s
,
v
};
});
return
result
;
}
}
// namespace migraph
src/include/
rtg
/argument.hpp
→
src/include/
migraph
/argument.hpp
View file @
fb75dfaf
#ifndef
RTG_GUARD_RTG
LIB_ARGUMENT_HPP
#define
RTG_GUARD_RTG
LIB_ARGUMENT_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_ARGUMENT_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_ARGUMENT_HPP
#include <
rtg
/shape.hpp>
#include <
rtg
/raw_data.hpp>
#include <
migraph
/shape.hpp>
#include <
migraph
/raw_data.hpp>
#include <functional>
namespace
rtg
{
namespace
migraph
{
/**
* @brief Arguments passed to instructions
...
...
@@ -39,16 +39,10 @@ struct argument : raw_data<argument>
const
shape
&
get_shape
()
const
{
return
this
->
m_shape
;
}
template
<
class
T
>
T
*
cast
()
const
{
return
reinterpret_cast
<
T
*>
(
this
->
data
());
}
private:
shape
m_shape
;
};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/migraph/auto_any_cast.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
namespace
migraph
{
namespace
detail
{
template
<
class
U
>
void
any_cast
()
{
}
template
<
class
T
>
struct
auto_any_caster
{
T
&
x
;
template
<
class
U
>
operator
U
&
()
{
return
any_cast
<
U
>
(
x
);
}
operator
T
&
()
{
return
x
;
}
};
}
// namespace detail
template
<
class
T
>
detail
::
auto_any_caster
<
T
>
auto_any_cast
(
T
&
x
)
{
return
{
x
};
}
}
// namespace migraph
#endif
src/include/migraph/auto_contiguous.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
auto_contiguous
{
std
::
string
name
()
const
{
return
"auto_contiguous"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/include/
rtg
/builtin.hpp
→
src/include/
migraph
/builtin.hpp
View file @
fb75dfaf
#ifndef
RTG
_GUARD_BUILTIN_HPP
#define
RTG
_GUARD_BUILTIN_HPP
#ifndef
MIGRAPH
_GUARD_BUILTIN_HPP
#define
MIGRAPH
_GUARD_BUILTIN_HPP
#include <rtg/operation.hpp>
#include <rtg/errors.hpp>
#include <migraph/context.hpp>
#include <migraph/errors.hpp>
#include <migraph/argument.hpp>
namespace
rtg
{
namespace
migraph
{
namespace
builtin
{
struct
literal
{
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG
_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG
_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
MIGRAPH
_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH
_THROW
(
"builtin"
);
}
};
struct
outline
{
shape
s
;
std
::
string
name
()
const
{
return
"@outline"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG_THROW
(
"builtin"
)
;
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG
_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
s
;
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH
_THROW
(
"builtin"
);
}
};
struct
param
{
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
RTG
_THROW
(
"builtin"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG
_THROW
(
"builtin"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
MIGRAPH
_THROW
(
"builtin"
);
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
MIGRAPH
_THROW
(
"builtin"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
param
&
op
)
{
os
<<
op
.
name
()
<<
":"
<<
op
.
parameter
;
...
...
@@ -38,6 +39,6 @@ struct param
}
// namespace builtin
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/migraph/check_context.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#include <migraph/program.hpp>
namespace
migraph
{
template
<
class
T
>
struct
check_context
{
struct
op
{
std
::
string
name
()
const
{
return
"check_context"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
shape
,
std
::
vector
<
argument
>
)
const
{
T
*
x
=
any_cast
<
T
>
(
&
ctx
);
if
(
x
==
nullptr
)
MIGRAPH_THROW
(
std
::
string
(
"Unexpected context type: "
)
+
ctx
.
type_id
().
name
());
return
{};
}
};
std
::
string
name
()
const
{
return
"check_context"
;
}
void
apply
(
program
&
p
)
const
{
p
.
insert_instruction
(
p
.
begin
(),
op
{});
}
};
}
// namespace migraph
#endif
src/include/migraph/check_shapes.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp>
#include <algorithm>
namespace
migraph
{
struct
check_shapes
{
const
std
::
vector
<
shape
>*
shapes
;
const
std
::
string
name
;
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
shapes
(
&
s
)
{}
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
)
:
shapes
(
&
s
),
name
(
op
.
name
())
{
}
std
::
string
prefix
()
const
{
if
(
name
.
empty
())
return
""
;
else
return
name
+
": "
;
}
const
check_shapes
&
has
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
size
()
!=
n
)
MIGRAPH_THROW
(
prefix
()
+
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
shapes
->
size
()));
return
*
this
;
}
const
check_shapes
&
only_dims
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
!
shapes
->
empty
())
{
if
(
shapes
->
front
().
lens
().
size
()
!=
n
)
MIGRAPH_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
}
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPH_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Dimensions do not match"
);
return
*
this
;
}
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
().
size
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
const
check_shapes
&
standard
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
}
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
}
const
check_shapes
&
not_transposed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPH_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
}
const
check_shapes
&
not_broadcasted
()
const
{
// if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
// MIGRAPH_THROW(prefix() + "Shapes are broadcasted");
return
*
this
;
}
template
<
class
F
>
bool
same
(
F
f
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
empty
())
return
true
;
auto
&&
key
=
f
(
shapes
->
front
());
return
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
f
(
s
)
==
key
;
});
}
template
<
class
Predicate
>
bool
all_of
(
Predicate
p
)
const
{
assert
(
shapes
!=
nullptr
);
return
std
::
all_of
(
shapes
->
begin
(),
shapes
->
end
(),
p
);
}
};
}
// namespace migraph
#endif
src/include/migraph/context.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace
migraph
{
#ifdef DOXYGEN
/// A context is used to store internal data for a `target`. A context is
/// constructed by a target during compilation and passed to the operations
/// during `eval`.
struct
context
{
};
#else
/*
* Type-erased interface for:
*
* struct context
* {
* };
*
*/
struct
context
{
// Constructors
context
()
=
default
;
template
<
typename
PrivateDetailTypeErasedT
>
context
(
PrivateDetailTypeErasedT
value
)
:
private_detail_te_handle_mem_var
(
std
::
make_shared
<
private_detail_te_handle_type
<
typename
std
::
remove_reference
<
PrivateDetailTypeErasedT
>::
type
>>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
)))
{
}
// Assignment
template
<
typename
PrivateDetailTypeErasedT
>
context
&
operator
=
(
PrivateDetailTypeErasedT
value
)
{
if
(
private_detail_te_handle_mem_var
.
unique
())
*
private_detail_te_handle_mem_var
=
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
);
else
if
(
!
private_detail_te_handle_mem_var
)
private_detail_te_handle_mem_var
=
std
::
make_shared
<
PrivateDetailTypeErasedT
>
(
std
::
forward
<
PrivateDetailTypeErasedT
>
(
value
));
return
*
this
;
}
// Cast
template
<
typename
PrivateDetailTypeErasedT
>
PrivateDetailTypeErasedT
*
any_cast
()
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
template
<
typename
PrivateDetailTypeErasedT
>
const
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
*
any_cast
()
const
{
return
private_detail_te_get_handle
().
type
()
==
typeid
(
PrivateDetailTypeErasedT
)
?
std
::
addressof
(
static_cast
<
const
private_detail_te_handle_type
<
typename
std
::
remove_cv
<
PrivateDetailTypeErasedT
>::
type
>&>
(
private_detail_te_get_handle
())
.
private_detail_te_value
)
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
private:
struct
private_detail_te_handle_base_type
{
virtual
~
private_detail_te_handle_base_type
()
{}
virtual
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
=
0
;
virtual
const
std
::
type_info
&
type
()
const
=
0
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
>::
type
*
=
nullptr
)
:
private_detail_te_value
(
value
)
{
}
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
clone
()
const
override
{
return
std
::
make_shared
<
private_detail_te_handle_type
>
(
private_detail_te_value
);
}
const
std
::
type_info
&
type
()
const
override
{
return
typeid
(
private_detail_te_value
);
}
PrivateDetailTypeErasedT
private_detail_te_value
;
};
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
<
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>>
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
{
private_detail_te_handle_type
(
std
::
reference_wrapper
<
PrivateDetailTypeErasedT
>
ref
)
:
private_detail_te_handle_type
<
PrivateDetailTypeErasedT
&>
(
ref
.
get
())
{
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
std
::
shared_ptr
<
private_detail_te_handle_base_type
>
private_detail_te_handle_mem_var
;
};
template
<
typename
ValueType
>
inline
const
ValueType
*
any_cast
(
const
context
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
*
any_cast
(
context
*
x
)
{
return
x
->
any_cast
<
ValueType
>
();
}
template
<
typename
ValueType
>
inline
ValueType
&
any_cast
(
context
&
x
)
{
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
template
<
typename
ValueType
>
inline
const
ValueType
&
any_cast
(
const
context
&
x
)
{
const
auto
*
y
=
x
.
any_cast
<
typename
std
::
remove_reference
<
ValueType
>::
type
>
();
if
(
y
==
nullptr
)
throw
std
::
bad_cast
();
return
*
y
;
}
#endif
}
// namespace migraph
#endif
src/include/migraph/dead_code_elimination.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace
migraph
{
struct
program
;
struct
dead_code_elimination
{
std
::
string
name
()
const
{
return
"dead_code_elimination"
;
}
void
apply
(
program
&
p
)
const
;
};
}
// namespace migraph
#endif
src/include/
rtg
/dfor.hpp
→
src/include/
migraph
/dfor.hpp
View file @
fb75dfaf
#ifndef
RTG_GUARD_RTG
LIB_DFOR_HPP
#define
RTG_GUARD_RTG
LIB_DFOR_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_DFOR_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_DFOR_HPP
namespace
rtg
{
namespace
migraph
{
// Multidimensional for loop
inline
auto
dfor
()
...
...
@@ -20,6 +20,6 @@ auto dfor(T x, Ts... xs)
};
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/erase.hpp
→
src/include/
migraph
/erase.hpp
View file @
fb75dfaf
#ifndef
RTG
_GUARD_ERASE_HPP
#define
RTG
_GUARD_ERASE_HPP
#ifndef
MIGRAPH
_GUARD_ERASE_HPP
#define
MIGRAPH
_GUARD_ERASE_HPP
namespace
rtg
{
#include <algorithm>
namespace
migraph
{
/**
* @brief Erase all elements from a container
...
...
@@ -29,6 +31,6 @@ auto erase_if(R&& r, P&& pred)
return
r
.
erase
(
std
::
remove_if
(
r
.
begin
(),
r
.
end
(),
pred
),
r
.
end
());
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/errors.hpp
→
src/include/
migraph
/errors.hpp
View file @
fb75dfaf
#ifndef
RTG
_GUARD_ERRORS_HPP
#define
RTG
_GUARD_ERRORS_HPP
#ifndef
MIGRAPH
_GUARD_ERRORS_HPP
#define
MIGRAPH
_GUARD_ERRORS_HPP
#include <exception>
#include <stdexcept>
#include <string>
namespace
rtg
{
namespace
migraph
{
/// Represents exceptions that can be thrown by
rtg
lib
/// Represents exceptions that can be thrown by
migraph
lib
struct
exception
:
std
::
runtime_error
{
exception
(
std
::
string
msg
=
""
)
:
std
::
runtime_error
(
msg
)
{}
...
...
@@ -41,9 +41,9 @@ inline std::string make_source_context(const std::string& file, int line)
/**
* @brief Throw an exception with context information
*/
#define
RTG
_THROW(...) \
throw
rtg
::make_exception(
rtg
::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
#define
MIGRAPH
_THROW(...) \
throw
migraph
::make_exception(
migraph
::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/migraph/fallthrough.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
namespace
migraph
{
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPH_FALLTHROUGH
#endif
}
// namespace migraph
#endif
src/include/
rtg
/float_equal.hpp
→
src/include/
migraph
/float_equal.hpp
View file @
fb75dfaf
#ifndef
RTG_GUARD_RTG
LIB_FLOAT_EQUAL_HPP
#define
RTG_GUARD_RTG
LIB_FLOAT_EQUAL_HPP
#ifndef
MIGRAPH_GUARD_MIGRAPH
LIB_FLOAT_EQUAL_HPP
#define
MIGRAPH_GUARD_MIGRAPH
LIB_FLOAT_EQUAL_HPP
#include <algorithm>
#include <cmath>
...
...
@@ -8,7 +8,7 @@
#include <iso646.h>
#endif
namespace
rtg
{
namespace
migraph
{
template
<
class
...
Ts
>
using
common_type
=
typename
std
::
common_type
<
Ts
...
>::
type
;
...
...
@@ -32,6 +32,6 @@ struct float_equal_fn
static
constexpr
float_equal_fn
float_equal
{};
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/migraph/functional.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
namespace
migraph
{
namespace
detail
{
template
<
class
R
,
class
F
>
struct
fix_f
{
F
f
;
template
<
class
...
Ts
>
R
operator
()(
Ts
&&
...
xs
)
const
{
return
f
(
*
this
,
std
::
forward
<
Ts
>
(
xs
)...);
}
};
}
// namespace detail
/// Implements a fix-point combinator
template
<
class
R
,
class
F
>
detail
::
fix_f
<
R
,
F
>
fix
(
F
f
)
{
return
{
f
};
}
template
<
class
F
>
auto
fix
(
F
f
)
{
return
fix
<
void
>
(
f
);
}
}
// namespace migraph
#endif
src/include/migraph/generate.hpp
0 → 100644
View file @
fb75dfaf
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraph/argument.hpp>
#include <migraph/literal.hpp>
#include <random>
namespace
migraph
{
template
<
class
T
>
std
::
vector
<
T
>
generate_tensor_data
(
migraph
::
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
)
{
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
mt19937
engine
{
seed
};
std
::
uniform_real_distribution
<>
dist
;
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
&
]
{
return
dist
(
engine
);
});
return
result
;
}
argument
generate_argument
(
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
);
literal
generate_literal
(
shape
s
,
std
::
mt19937
::
result_type
seed
=
0
);
}
// namespace migraph
#endif
src/include/
rtg
/instruction.hpp
→
src/include/
migraph
/instruction.hpp
View file @
fb75dfaf
#ifndef RTG_GUARD_RTGLIB_INSTRUCTION_HPP
#define RTG_GUARD_RTGLIB_INSTRUCTION_HPP
#include <rtg/literal.hpp>
#include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <rtg/erase.hpp>
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <migraph/literal.hpp>
#include <migraph/shape.hpp>
#include <migraph/builtin.hpp>
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <string>
namespace
rtg
{
namespace
migraph
{
shape
compute_shape
(
operation
op
,
std
::
vector
<
instruction_ref
>
args
);
...
...
@@ -37,23 +38,33 @@ struct instruction
result
=
r
;
for
(
auto
&&
ins
:
output
)
{
ins
->
replace
(
compute_shape
(
ins
->
op
,
ins
->
arguments
));
assert
(
ins
->
op
.
name
().
front
()
!=
'@'
);
ins
->
recompute_shape
();
}
}
}
void
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
void
replace_argument
(
instruction_ref
old
,
instruction_ref
new_ins
)
{
std
::
replace
(
arguments
.
begin
(),
arguments
.
end
(),
old
,
new_ins
);
old
->
remove_output
(
*
this
);
}
void
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
rtg
::
erase
(
arg
->
output
,
*
this
);
arg
->
remove_
output
(
*
this
);
}
arguments
.
clear
();
}
friend
bool
operator
==
(
const
instruction
&
i
,
instruction_ref
ref
)
...
...
@@ -61,25 +72,64 @@ struct instruction
return
std
::
addressof
(
i
)
==
std
::
addressof
(
*
ref
);
}
bool
valid
(
instruction_ref
start
)
const
{
return
valid
()
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
auto
self
=
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
);
return
self
!=
i
->
output
.
end
()
&&
std
::
distance
(
start
,
i
)
<
std
::
distance
(
start
,
*
self
);
});
}
bool
valid
()
const
{
return
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
})
&&
std
::
all_of
(
arguments
.
begin
(),
arguments
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
output
.
begin
(),
i
->
output
.
end
(),
*
this
)
!=
i
->
output
.
end
();
shape
computed
;
if
(
op
.
name
()
==
"@literal"
)
{
computed
=
lit
.
get_shape
();
}
else
if
(
op
.
name
()
==
"@param"
)
{
computed
=
result
;
}
else
{
try
{
computed
=
compute_shape
(
op
,
arguments
);
}
catch
(
migraph
::
exception
&
)
{
return
false
;
}
}
return
result
==
computed
&&
std
::
all_of
(
output
.
begin
(),
output
.
end
(),
[
&
](
instruction_ref
i
)
{
return
std
::
find
(
i
->
arguments
.
begin
(),
i
->
arguments
.
end
(),
*
this
)
!=
i
->
arguments
.
end
();
});
}
shape
get_shape
()
const
{
return
result
;
}
friend
bool
operator
==
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
i
==
ref
;
}
friend
bool
operator
!=
(
const
instruction
&
i
,
instruction_ref
ref
)
{
return
!
(
i
==
ref
);
}
friend
bool
operator
!=
(
instruction_ref
ref
,
const
instruction
&
i
)
{
return
!
(
i
==
ref
);
}
void
add_output
(
instruction_ref
ins
)
{
if
(
std
::
find
(
output
.
begin
(),
output
.
end
(),
ins
)
==
output
.
end
())
output
.
push_back
(
ins
);
}
template
<
class
T
>
void
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
operation
op
;
shape
result
;
std
::
vector
<
instruction_ref
>
output
;
...
...
@@ -90,7 +140,14 @@ struct instruction
inline
void
backreference
(
instruction_ref
ref
)
{
for
(
auto
&&
arg
:
ref
->
arguments
)
arg
->
output
.
push_back
(
ref
);
arg
->
add_output
(
ref
);
}
inline
void
replace_argument
(
instruction_ref
ins
,
instruction_ref
old
,
instruction_ref
new_ins
)
{
ins
->
replace_argument
(
old
,
new_ins
);
backreference
(
ins
);
ins
->
recompute_shape
();
}
// TODO: Move to a cpp file
...
...
@@ -103,6 +160,6 @@ inline shape compute_shape(operation op, std::vector<instruction_ref> args)
return
op
.
compute_shape
(
shapes
);
}
}
// namespace
rtg
}
// namespace
migraph
#endif
src/include/
rtg
/instruction_ref.hpp
→
src/include/
migraph
/instruction_ref.hpp
View file @
fb75dfaf
#ifndef
RTG
_GUARD_INSTRUCTION_REF_HPP
#define
RTG
_GUARD_INSTRUCTION_REF_HPP
#ifndef
MIGRAPH
_GUARD_INSTRUCTION_REF_HPP
#define
MIGRAPH
_GUARD_INSTRUCTION_REF_HPP
#include <list>
namespace
rtg
{
namespace
migraph
{
struct
instruction
;
using
instruction_ref
=
std
::
list
<
instruction
>::
iterator
;
}
// namespace
rtg
}
// namespace
migraph
#endif
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment