Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a536b16b
Commit
a536b16b
authored
Apr 12, 2018
by
Paul
Browse files
Make operator class extensible
parent
b1e9363f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
186 additions
and
56 deletions
+186
-56
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+15
-0
include/rtg/instruction.hpp
include/rtg/instruction.hpp
+2
-1
include/rtg/operand.hpp
include/rtg/operand.hpp
+123
-4
include/rtg/operators.hpp
include/rtg/operators.hpp
+8
-0
include/rtg/program.hpp
include/rtg/program.hpp
+4
-9
src/program.cpp
src/program.cpp
+2
-2
test/eval_test.cpp
test/eval_test.cpp
+32
-40
No files found.
include/rtg/builtin.hpp
0 → 100644
View file @
a536b16b
#ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP
namespace
rtg
{
namespace
builtin
{
static
const
char
*
literal
=
"@literal"
;
static
const
char
*
param
=
"@param"
;
}
}
// namespace rtg
#endif
include/rtg/instruction.hpp
View file @
a536b16b
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <rtg/literal.hpp>
#include <rtg/literal.hpp>
#include <rtg/shape.hpp>
#include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <string>
#include <string>
namespace
rtg
{
namespace
rtg
{
...
@@ -16,7 +17,7 @@ struct instruction
...
@@ -16,7 +17,7 @@ struct instruction
{}
{}
instruction
(
literal
l
)
instruction
(
literal
l
)
:
name
(
"
literal
"
),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
:
name
(
builtin
::
literal
),
result
(
l
.
get_shape
()),
lit
(
std
::
move
(
l
))
{}
{}
std
::
string
name
;
std
::
string
name
;
...
...
include/rtg/operand.hpp
View file @
a536b16b
...
@@ -3,16 +3,135 @@
...
@@ -3,16 +3,135 @@
#include <string>
#include <string>
#include <functional>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
#include <rtg/argument.hpp>
namespace
rtg
{
namespace
rtg
{
struct
operand
/*
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
struct
operand
{
{
std
::
string
name
;
// Constructors
std
::
function
<
shape
(
std
::
vector
<
shape
>
)
>
compute_shape
;
operand
()
=
default
;
std
::
function
<
argument
(
std
::
vector
<
argument
>
)
>
compute
;
template
<
typename
TypeErased_T_
>
operand
(
TypeErased_T_
value
)
:
handle_mem_var_
(
std
::
make_shared
<
handle_type_
<
typename
std
::
remove_reference
<
TypeErased_T_
>::
type
>>
(
std
::
forward
<
TypeErased_T_
>
(
value
)))
{
}
// Assignment
template
<
typename
TypeErased_T_
>
operand
&
operator
=
(
TypeErased_T_
value
)
{
if
(
handle_mem_var_
.
unique
())
*
handle_mem_var_
=
std
::
forward
<
TypeErased_T_
>
(
value
);
else
if
(
!
handle_mem_var_
)
handle_mem_var_
=
std
::
make_shared
<
TypeErased_T_
>
(
std
::
forward
<
TypeErased_T_
>
(
value
));
return
*
this
;
}
std
::
string
name
()
const
{
assert
(
handle_mem_var_
);
return
get_handle_
().
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
assert
(
handle_mem_var_
);
return
get_handle_
().
compute_shape
(
std
::
move
(
input
));
}
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
assert
(
handle_mem_var_
);
return
get_handle_
().
compute
(
std
::
move
(
input
));
}
private:
struct
handle_base_type_
{
virtual
~
handle_base_type_
()
{}
virtual
std
::
shared_ptr
<
handle_base_type_
>
clone
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
=
0
;
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
=
0
;
};
template
<
typename
TypeErased_T_
>
struct
handle_type_
:
handle_base_type_
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
0
)
:
value_
(
value
)
{
}
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
0
)
noexcept
:
value_
(
std
::
move
(
value
))
{
}
virtual
std
::
shared_ptr
<
handle_base_type_
>
clone
()
const
{
return
std
::
make_shared
<
handle_type_
>
(
value_
);
}
virtual
std
::
string
name
()
const
{
return
value_
.
name
();
}
virtual
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
return
value_
.
compute_shape
(
std
::
move
(
input
));
}
virtual
argument
compute
(
std
::
vector
<
argument
>
input
)
const
{
return
value_
.
compute
(
std
::
move
(
input
));
}
TypeErased_T_
value_
;
};
template
<
typename
TypeErased_T_
>
struct
handle_type_
<
std
::
reference_wrapper
<
TypeErased_T_
>>
:
handle_type_
<
TypeErased_T_
&>
{
handle_type_
(
std
::
reference_wrapper
<
TypeErased_T_
>
ref
)
:
handle_type_
<
TypeErased_T_
&>
(
ref
.
get
())
{
}
};
const
handle_base_type_
&
get_handle_
()
const
{
return
*
handle_mem_var_
;
}
handle_base_type_
&
get_handle_
()
{
if
(
!
handle_mem_var_
.
unique
())
handle_mem_var_
=
handle_mem_var_
->
clone
();
return
*
handle_mem_var_
;
}
std
::
shared_ptr
<
handle_base_type_
>
handle_mem_var_
;
};
};
}
}
...
...
include/rtg/operators.hpp
0 → 100644
View file @
a536b16b
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
namespace
rtg
{
}
// namespace rtg
#endif
include/rtg/program.hpp
View file @
a536b16b
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include <unordered_map>
#include <unordered_map>
#include <rtg/instruction.hpp>
#include <rtg/instruction.hpp>
#include <rtg/operand.hpp>
#include <rtg/operand.hpp>
#include <rtg/builtin.hpp>
namespace
rtg
{
namespace
rtg
{
...
@@ -27,18 +28,13 @@ struct program
...
@@ -27,18 +28,13 @@ struct program
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
{
{
instructions
.
push_back
({
"
param
:"
+
std
::
move
(
name
),
s
,
{}});
instructions
.
push_back
({
builtin
::
param
+
std
::
move
(
name
),
s
,
{}});
return
std
::
addressof
(
instructions
.
back
());
return
std
::
addressof
(
instructions
.
back
());
}
}
template
<
class
Op
,
class
Shape
>
void
add_operator
(
operand
op
)
void
add_operator
(
std
::
string
name
,
Op
op
,
Shape
s
)
{
{
operand
result
;
ops
.
emplace
(
op
.
name
(),
op
);
result
.
name
=
name
;
result
.
compute
=
op
;
result
.
compute_shape
=
s
;
ops
.
emplace
(
name
,
result
);
}
}
literal
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
literal
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
...
@@ -48,7 +44,6 @@ private:
...
@@ -48,7 +44,6 @@ private:
std
::
list
<
instruction
>
instructions
;
std
::
list
<
instruction
>
instructions
;
std
::
unordered_map
<
std
::
string
,
operand
>
ops
;
std
::
unordered_map
<
std
::
string
,
operand
>
ops
;
};
};
}
}
...
...
src/program.cpp
View file @
a536b16b
...
@@ -10,11 +10,11 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -10,11 +10,11 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
argument
result
;
argument
result
;
for
(
auto
&
ins
:
instructions
)
for
(
auto
&
ins
:
instructions
)
{
{
if
(
ins
.
name
==
"
literal
"
)
if
(
ins
.
name
==
builtin
::
literal
)
{
{
result
=
ins
.
lit
.
get_argument
();
result
=
ins
.
lit
.
get_argument
();
}
}
else
if
(
starts_with
(
ins
.
name
,
"
param
:"
))
else
if
(
starts_with
(
ins
.
name
,
builtin
::
param
))
{
{
result
=
params
.
at
(
ins
.
name
.
substr
(
6
));
result
=
params
.
at
(
ins
.
name
.
substr
(
6
));
}
}
...
...
test/eval_test.cpp
View file @
a536b16b
...
@@ -4,28 +4,39 @@
...
@@ -4,28 +4,39 @@
#include <rtg/shape.hpp>
#include <rtg/shape.hpp>
#include "test.hpp"
#include "test.hpp"
void
literal_test
()
{
rtg
::
program
p
;
p
.
add_operator
(
"sum"
,
[](
std
::
vector
<
rtg
::
argument
>
args
)
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit_at
([
&
](
auto
x
)
{
struct
sum_op
args
[
1
].
visit_at
([
&
](
auto
y
)
{
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
std
::
string
name
()
const
});
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
});
return
result
;
});
},
return
result
;
[](
std
::
vector
<
rtg
::
shape
>
inputs
)
{
}
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
return
inputs
.
front
();
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
inputs
)
const
}
{
);
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
return
inputs
.
front
();
}
};
void
literal_test
()
{
rtg
::
program
p
;
p
.
add_operator
(
sum_op
{});
auto
one
=
p
.
add_literal
(
1
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
two
=
p
.
add_literal
(
2
);
...
@@ -37,26 +48,7 @@ void literal_test() {
...
@@ -37,26 +48,7 @@ void literal_test() {
void
param_test
()
{
void
param_test
()
{
rtg
::
program
p
;
rtg
::
program
p
;
p
.
add_operator
(
"sum"
,
p
.
add_operator
(
sum_op
{});
[](
std
::
vector
<
rtg
::
argument
>
args
)
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
});
return
result
;
},
[](
std
::
vector
<
rtg
::
shape
>
inputs
)
{
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
return
inputs
.
front
();
}
);
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int_type
});
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int_type
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment