Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3d03e158
Commit
3d03e158
authored
Jun 13, 2018
by
Paul
Browse files
IMprove implicit casting and add a method to extract a tensor view directly
parent
9a24414f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
14 deletions
+29
-14
src/include/rtg/raw_data.hpp
src/include/rtg/raw_data.hpp
+18
-3
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+11
-11
No files found.
src/include/rtg/raw_data.hpp
View file @
3d03e158
...
...
@@ -59,6 +59,7 @@ struct raw_data : raw_data_base
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
/// Returns true if the raw data is only one element
bool
single
()
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
...
...
@@ -86,17 +87,31 @@ struct raw_data : raw_data_base
template
<
class
T
>
operator
T
()
{
assert
(
self
->
single
());
return
self
->
template
at
<
T
>();
}
template
<
class
T
>
operator
T
*
()
{
// TODO: Check type
return
reinterpret_cast
<
T
*>
(
self
->
data
());
using
type
=
std
::
remove_cv_t
<
T
>
;
assert
((
std
::
is_void
<
T
>
{}
or
std
::
is_same
<
char
,
type
>
{}
or
std
::
is_same
<
unsigned
char
,
type
>
{}
or
self
->
get_shape
().
type
()
==
rtg
::
shape
::
get_type
<
T
>
{}));
return
reinterpret_cast
<
type
*>
(
self
->
data
());
}
};
auto_cast
get
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
/// Implicit conversion of raw data pointer
auto_cast
implicit
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
/// Get a tensor_view to the data
template
<
class
T
>
tensor_view
<
T
>
get
()
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
s
.
type
()
!=
rtg
::
shape
::
get_type
<
T
>
{})
RTG_THROW
(
"Incorrect data type for raw data"
);
return
make_view
(
s
,
reinterpret_cast
<
T
*>
(
buffer
));
}
};
template
<
class
T
,
...
...
src/targets/miopen/miopen_target.cpp
View file @
3d03e158
...
...
@@ -115,31 +115,31 @@ struct miopen_convolution
float
alpha
=
1
,
beta
=
0
;
int
algo_count
;
miopenConvAlgoPerf_t
perf
;
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
ge
t
(),
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
implici
t
(),
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
w_desc
.
get
(),
args
[
2
].
ge
t
(),
args
[
2
].
implici
t
(),
cd
.
get
(),
y_desc
.
get
(),
args
[
3
].
ge
t
(),
args
[
3
].
implici
t
(),
1
,
&
algo_count
,
&
perf
,
nullptr
,
0
,
false
);
miopenConvolutionForward
(
args
[
0
].
ge
t
(),
miopenConvolutionForward
(
args
[
0
].
implici
t
(),
&
alpha
,
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
w_desc
.
get
(),
args
[
2
].
ge
t
(),
args
[
2
].
implici
t
(),
cd
.
get
(),
perf
.
fwd_algo
,
&
beta
,
y_desc
.
get
(),
args
[
3
].
ge
t
(),
args
[
3
].
implici
t
(),
nullptr
,
0
);
return
args
[
3
];
...
...
@@ -161,14 +161,14 @@ struct miopen_relu
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
args
[
0
].
ge
t
(),
miopenActivationForward
(
args
[
0
].
implici
t
(),
ad
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
&
beta
,
y_desc
.
get
(),
args
[
2
].
ge
t
());
args
[
2
].
implici
t
());
return
args
[
2
];
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment