Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3d03e158
Commit
3d03e158
authored
Jun 13, 2018
by
Paul
Browse files
IMprove implicit casting and add a method to extract a tensor view directly
parent
9a24414f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
14 deletions
+29
-14
src/include/rtg/raw_data.hpp
src/include/rtg/raw_data.hpp
+18
-3
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+11
-11
No files found.
src/include/rtg/raw_data.hpp
View file @
3d03e158
...
@@ -59,6 +59,7 @@ struct raw_data : raw_data_base
...
@@ -59,6 +59,7 @@ struct raw_data : raw_data_base
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
}
/// Returns true if the raw data is only one element
bool
single
()
const
bool
single
()
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
...
@@ -86,17 +87,31 @@ struct raw_data : raw_data_base
...
@@ -86,17 +87,31 @@ struct raw_data : raw_data_base
template
<
class
T
>
template
<
class
T
>
operator
T
()
operator
T
()
{
{
assert
(
self
->
single
());
return
self
->
template
at
<
T
>();
return
self
->
template
at
<
T
>();
}
}
template
<
class
T
>
template
<
class
T
>
operator
T
*
()
operator
T
*
()
{
{
// TODO: Check type
using
type
=
std
::
remove_cv_t
<
T
>
;
return
reinterpret_cast
<
T
*>
(
self
->
data
());
assert
((
std
::
is_void
<
T
>
{}
or
std
::
is_same
<
char
,
type
>
{}
or
std
::
is_same
<
unsigned
char
,
type
>
{}
or
self
->
get_shape
().
type
()
==
rtg
::
shape
::
get_type
<
T
>
{}));
return
reinterpret_cast
<
type
*>
(
self
->
data
());
}
}
};
};
auto_cast
get
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
/// Implicit conversion of raw data pointer
auto_cast
implicit
()
const
{
return
{
static_cast
<
const
Derived
*>
(
this
)};
}
/// Get a tensor_view to the data
template
<
class
T
>
tensor_view
<
T
>
get
()
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
if
(
s
.
type
()
!=
rtg
::
shape
::
get_type
<
T
>
{})
RTG_THROW
(
"Incorrect data type for raw data"
);
return
make_view
(
s
,
reinterpret_cast
<
T
*>
(
buffer
));
}
};
};
template
<
class
T
,
template
<
class
T
,
...
...
src/targets/miopen/miopen_target.cpp
View file @
3d03e158
...
@@ -115,31 +115,31 @@ struct miopen_convolution
...
@@ -115,31 +115,31 @@ struct miopen_convolution
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
int
algo_count
;
int
algo_count
;
miopenConvAlgoPerf_t
perf
;
miopenConvAlgoPerf_t
perf
;
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
ge
t
(),
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
implici
t
(),
x_desc
.
get
(),
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
w_desc
.
get
(),
w_desc
.
get
(),
args
[
2
].
ge
t
(),
args
[
2
].
implici
t
(),
cd
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
y_desc
.
get
(),
args
[
3
].
ge
t
(),
args
[
3
].
implici
t
(),
1
,
1
,
&
algo_count
,
&
algo_count
,
&
perf
,
&
perf
,
nullptr
,
nullptr
,
0
,
0
,
false
);
false
);
miopenConvolutionForward
(
args
[
0
].
ge
t
(),
miopenConvolutionForward
(
args
[
0
].
implici
t
(),
&
alpha
,
&
alpha
,
x_desc
.
get
(),
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
w_desc
.
get
(),
w_desc
.
get
(),
args
[
2
].
ge
t
(),
args
[
2
].
implici
t
(),
cd
.
get
(),
cd
.
get
(),
perf
.
fwd_algo
,
perf
.
fwd_algo
,
&
beta
,
&
beta
,
y_desc
.
get
(),
y_desc
.
get
(),
args
[
3
].
ge
t
(),
args
[
3
].
implici
t
(),
nullptr
,
nullptr
,
0
);
0
);
return
args
[
3
];
return
args
[
3
];
...
@@ -161,14 +161,14 @@ struct miopen_relu
...
@@ -161,14 +161,14 @@ struct miopen_relu
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
args
[
0
].
ge
t
(),
miopenActivationForward
(
args
[
0
].
implici
t
(),
ad
.
get
(),
ad
.
get
(),
&
alpha
,
&
alpha
,
x_desc
.
get
(),
x_desc
.
get
(),
args
[
1
].
ge
t
(),
args
[
1
].
implici
t
(),
&
beta
,
&
beta
,
y_desc
.
get
(),
y_desc
.
get
(),
args
[
2
].
ge
t
());
args
[
2
].
implici
t
());
return
args
[
2
];
return
args
[
2
];
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment