Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dc21cca1
Commit
dc21cca1
authored
May 30, 2018
by
Paul
Browse files
Formatting
parent
f20f1990
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
12 deletions
+45
-12
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+45
-12
No files found.
src/targets/miopen/miopen_target.cpp
View file @
dc21cca1
...
...
@@ -15,14 +15,13 @@ using convolution_descriptor = RTG_MANAGE_PTR(miopenConvolutionDescriptor_t,
using
activation_descriptor
=
RTG_MANAGE_PTR
(
miopenActivationDescriptor_t
,
miopenDestroyActivationDescriptor
);
template
<
class
Result
,
class
F
,
class
...
Ts
>
template
<
class
Result
,
class
F
,
class
...
Ts
>
Result
make_obj
(
F
f
,
Ts
...
xs
)
{
typename
Result
::
pointer
x
=
nullptr
;
auto
status
=
f
(
&
x
,
xs
...);
Result
r
{
x
};
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
RTG_THROW
(
"MIOpen call failed"
);
return
r
;
}
...
...
@@ -34,8 +33,10 @@ tensor_descriptor make_tensor(const rtg::shape& s)
std
::
vector
<
int
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
());
std
::
vector
<
int
>
strides
(
s
.
strides
().
begin
(),
s
.
strides
().
end
());
miopenDataType_t
d
;
if
(
s
.
type
()
==
shape
::
float_type
)
d
=
miopenFloat
;
else
RTG_THROW
(
"Unsupported type"
);
if
(
s
.
type
()
==
shape
::
float_type
)
d
=
miopenFloat
;
else
RTG_THROW
(
"Unsupported type"
);
miopenSetTensorDescriptor
(
t
.
get
(),
d
,
s
.
lens
().
size
(),
lens
.
data
(),
strides
.
data
());
return
t
;
}
...
...
@@ -43,7 +44,14 @@ tensor_descriptor make_tensor(const rtg::shape& s)
convolution_descriptor
make_conv
(
const
rtg
::
convolution
&
op
)
{
auto
c
=
make_obj
<
convolution_descriptor
>
(
&
miopenCreateConvolutionDescriptor
);
miopenInitConvolutionDescriptor
(
c
.
get
(),
miopenConvolution
,
op
.
padding
[
0
],
op
.
padding
[
1
],
op
.
stride
[
0
],
op
.
stride
[
1
],
op
.
dilation
[
0
],
op
.
dilation
[
1
]);
miopenInitConvolutionDescriptor
(
c
.
get
(),
miopenConvolution
,
op
.
padding
[
0
],
op
.
padding
[
1
],
op
.
stride
[
0
],
op
.
stride
[
1
],
op
.
dilation
[
0
],
op
.
dilation
[
1
]);
return
c
;
}
...
...
@@ -66,8 +74,33 @@ struct miopen_convolution
int
algo_count
;
miopenConvAlgoPerf_t
perf
;
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
data
(),
x_desc
.
get
(),
args
[
1
].
data
(),
w_desc
,
args
[
2
].
data
(),
cd
.
get
(),
y_desc
,
args
[
4
].
data
(),
1
,
&
algo_count
,
&
perf
,
args
[
3
].
data
(),
args
[
3
].
get_shape
().
bytes
(),
false
);
miopenConvolutionForward
(
args
[
0
].
data
(),
&
alpha
,
x_desc
,
args
[
1
].
data
(),
w_desc
,
args
[
2
].
data
(),
cd
.
get
(),
perf
.
fwd_algo
,
&
beta
,
y_desc
,
args
[
4
].
data
(),
args
[
3
].
data
(),
args
[
3
].
get_shape
().
bytes
());
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
data
(),
x_desc
.
get
(),
args
[
1
].
data
(),
w_desc
,
args
[
2
].
data
(),
cd
.
get
(),
y_desc
,
args
[
4
].
data
(),
1
,
&
algo_count
,
&
perf
,
args
[
3
].
data
(),
args
[
3
].
get_shape
().
bytes
(),
false
);
miopenConvolutionForward
(
args
[
0
].
data
(),
&
alpha
,
x_desc
,
args
[
1
].
data
(),
w_desc
,
args
[
2
].
data
(),
cd
.
get
(),
perf
.
fwd_algo
,
&
beta
,
y_desc
,
args
[
4
].
data
(),
args
[
3
].
data
(),
args
[
3
].
get_shape
().
bytes
());
return
result
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment