Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1b18e9d0
Commit
1b18e9d0
authored
Jun 08, 2018
by
Paul
Browse files
Return an argument instead of a literal from eval
parent
8ff0905c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
41 deletions
+53
-41
src/include/rtg/program.hpp
src/include/rtg/program.hpp
+1
-1
src/include/rtg/raw_data.hpp
src/include/rtg/raw_data.hpp
+32
-22
src/include/rtg/tensor_view.hpp
src/include/rtg/tensor_view.hpp
+18
-16
src/program.cpp
src/program.cpp
+2
-2
No files found.
src/include/rtg/program.hpp
View file @
1b18e9d0
...
@@ -62,7 +62,7 @@ struct program
...
@@ -62,7 +62,7 @@ struct program
shape
get_parameter_shape
(
std
::
string
name
);
shape
get_parameter_shape
(
std
::
string
name
);
literal
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
argument
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
...
...
src/include/rtg/raw_data.hpp
View file @
1b18e9d0
...
@@ -6,6 +6,11 @@
...
@@ -6,6 +6,11 @@
namespace
rtg
{
namespace
rtg
{
#define RTG_REQUIRES(...) class=typename std::enable_if<(__VA_ARGS__)>::type
struct
raw_data_base
{};
/**
/**
* @brief Provides a base class for common operations with raw buffer
* @brief Provides a base class for common operations with raw buffer
*
*
...
@@ -15,29 +20,8 @@ namespace rtg {
...
@@ -15,29 +20,8 @@ namespace rtg {
*
*
*/
*/
template
<
class
Derived
>
template
<
class
Derived
>
struct
raw_data
struct
raw_data
:
raw_data_base
{
{
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
{
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
if
(
not
result
&&
xshape
==
yshape
)
{
auto
&&
xbuffer
=
x
.
data
();
auto
&&
ybuffer
=
y
.
data
();
// TODO: Dont use tensor view for single values
xshape
.
visit_type
([
&
](
auto
as
)
{
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
result
=
xview
==
yview
;
});
}
return
result
;
}
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
Stream
>
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
{
...
@@ -114,6 +98,32 @@ struct raw_data
...
@@ -114,6 +98,32 @@ struct raw_data
auto_cast
get
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
auto_cast
get
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
};
};
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}),
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
==
(
const
T
&
x
,
const
U
&
y
)
{
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
if
(
not
result
&&
xshape
==
yshape
)
{
auto
&&
xbuffer
=
x
.
data
();
auto
&&
ybuffer
=
y
.
data
();
// TODO: Dont use tensor view for single values
xshape
.
visit_type
([
&
](
auto
as
)
{
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
result
=
xview
==
yview
;
});
}
return
result
;
}
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}),
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
}
namespace
detail
{
namespace
detail
{
template
<
class
V
,
class
...
Ts
>
template
<
class
V
,
class
...
Ts
>
void
visit_all_impl
(
const
shape
&
s
,
V
&&
v
,
Ts
&&
...
xs
)
void
visit_all_impl
(
const
shape
&
s
,
V
&&
v
,
Ts
&&
...
xs
)
...
...
src/include/rtg/tensor_view.hpp
View file @
1b18e9d0
...
@@ -103,22 +103,6 @@ struct tensor_view
...
@@ -103,22 +103,6 @@ struct tensor_view
return
m_data
+
this
->
size
();
return
m_data
+
this
->
size
();
}
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
if
(
x
.
m_shape
==
y
.
m_shape
)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
m_shape
.
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
}
return
false
;
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
...
@@ -137,6 +121,24 @@ struct tensor_view
...
@@ -137,6 +121,24 @@ struct tensor_view
shape
m_shape
;
shape
m_shape
;
};
};
template
<
class
T
,
class
U
>
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
if
(
x
.
get_shape
()
==
y
.
get_shape
())
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
}
return
false
;
}
template
<
class
T
,
class
U
>
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
T
>
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
{
{
...
...
src/program.cpp
View file @
1b18e9d0
...
@@ -112,7 +112,7 @@ void program::compile(const target& t)
...
@@ -112,7 +112,7 @@ void program::compile(const target& t)
RTG_THROW
(
"Invalid program from compilation"
);
RTG_THROW
(
"Invalid program from compilation"
);
}
}
literal
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
argument
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
{
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
...
@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
}
return
literal
{
result
.
get_shape
(),
result
.
data
()}
;
return
result
;
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment