Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1b18e9d0
"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "6c7d99928251e03249ac2c65006c7452f5676bb7"
Commit
1b18e9d0
authored
Jun 08, 2018
by
Paul
Browse files
Return an argument instead of a literal from eval
parent
8ff0905c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
41 deletions
+53
-41
src/include/rtg/program.hpp
src/include/rtg/program.hpp
+1
-1
src/include/rtg/raw_data.hpp
src/include/rtg/raw_data.hpp
+32
-22
src/include/rtg/tensor_view.hpp
src/include/rtg/tensor_view.hpp
+18
-16
src/program.cpp
src/program.cpp
+2
-2
No files found.
src/include/rtg/program.hpp
View file @
1b18e9d0
...
...
@@ -62,7 +62,7 @@ struct program
shape
get_parameter_shape
(
std
::
string
name
);
literal
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
argument
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
;
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
);
...
...
src/include/rtg/raw_data.hpp
View file @
1b18e9d0
...
...
@@ -6,6 +6,11 @@
namespace
rtg
{
#define RTG_REQUIRES(...) class=typename std::enable_if<(__VA_ARGS__)>::type
struct
raw_data_base
{};
/**
* @brief Provides a base class for common operations with raw buffer
*
...
...
@@ -15,29 +20,8 @@ namespace rtg {
*
*/
template
<
class
Derived
>
struct
raw_data
struct
raw_data
:
raw_data_base
{
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
{
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
if
(
not
result
&&
xshape
==
yshape
)
{
auto
&&
xbuffer
=
x
.
data
();
auto
&&
ybuffer
=
y
.
data
();
// TODO: Dont use tensor view for single values
xshape
.
visit_type
([
&
](
auto
as
)
{
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
result
=
xview
==
yview
;
});
}
return
result
;
}
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
...
...
@@ -114,6 +98,32 @@ struct raw_data
auto_cast
get
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
};
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}),
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
==
(
const
T
&
x
,
const
U
&
y
)
{
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
if
(
not
result
&&
xshape
==
yshape
)
{
auto
&&
xbuffer
=
x
.
data
();
auto
&&
ybuffer
=
y
.
data
();
// TODO: Dont use tensor view for single values
xshape
.
visit_type
([
&
](
auto
as
)
{
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
result
=
xview
==
yview
;
});
}
return
result
;
}
template
<
class
T
,
class
U
,
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
T
>{}),
RTG_REQUIRES
(
std
::
is_base_of
<
raw_data_base
,
U
>
{})
>
bool
operator
!=
(
const
T
&
x
,
const
U
&
y
)
{
return
!
(
x
==
y
);
}
namespace
detail
{
template
<
class
V
,
class
...
Ts
>
void
visit_all_impl
(
const
shape
&
s
,
V
&&
v
,
Ts
&&
...
xs
)
...
...
src/include/rtg/tensor_view.hpp
View file @
1b18e9d0
...
...
@@ -103,22 +103,6 @@ struct tensor_view
return
m_data
+
this
->
size
();
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
if
(
x
.
m_shape
==
y
.
m_shape
)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
m_shape
.
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
}
return
false
;
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
if
(
!
x
.
empty
())
...
...
@@ -137,6 +121,24 @@ struct tensor_view
shape
m_shape
;
};
template
<
class
T
,
class
U
>
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
if
(
x
.
get_shape
()
==
y
.
get_shape
())
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
get_shape
().
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
}
return
false
;
}
template
<
class
T
,
class
U
>
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
U
>&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
{
...
...
src/program.cpp
View file @
1b18e9d0
...
...
@@ -112,7 +112,7 @@ void program::compile(const target& t)
RTG_THROW
(
"Invalid program from compilation"
);
}
literal
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
argument
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
assert
(
this
->
validate
()
!=
impl
->
instructions
.
end
());
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
...
...
@@ -142,7 +142,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
}
return
literal
{
result
.
get_shape
(),
result
.
data
()}
;
return
result
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
program
&
p
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment