Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8e0fff81
Commit
8e0fff81
authored
Jun 29, 2018
by
Paul
Browse files
Store handle in context
parent
43ae3419
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
166 additions
and
69 deletions
+166
-69
src/include/rtg/context.hpp
src/include/rtg/context.hpp
+15
-0
src/include/rtg/operation.hpp
src/include/rtg/operation.hpp
+15
-0
src/include/rtg/operators.hpp
src/include/rtg/operators.hpp
+41
-14
src/include/rtg/target.hpp
src/include/rtg/target.hpp
+15
-0
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+1
-1
src/targets/miopen/include/rtg/miopen/miopen_target.hpp
src/targets/miopen/include/rtg/miopen/miopen_target.hpp
+1
-1
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+62
-50
test/miopen/miopen.cpp
test/miopen/miopen.cpp
+0
-2
tools/te.py
tools/te.py
+16
-1
No files found.
src/include/rtg/context.hpp
View file @
8e0fff81
...
...
@@ -62,6 +62,14 @@ struct context
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
private:
struct
private_detail_te_handle_base_type
{
...
...
@@ -111,13 +119,20 @@ struct context
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
...
...
src/include/rtg/operation.hpp
View file @
8e0fff81
...
...
@@ -84,6 +84,14 @@ struct operation
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -183,13 +191,20 @@ struct operation
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
...
...
src/include/rtg/operators.hpp
View file @
8e0fff81
...
...
@@ -12,14 +12,24 @@ namespace rtg {
struct
check_shapes
{
const
std
::
vector
<
shape
>*
shapes
;
const
std
::
string
name
;
check_shapes
(
const
std
::
vector
<
shape
>&
s
)
:
shapes
(
&
s
)
{}
template
<
class
Op
>
check_shapes
(
const
std
::
vector
<
shape
>&
s
,
const
Op
&
op
)
:
shapes
(
&
s
),
name
(
op
.
name
())
{}
std
::
string
prefix
()
const
{
if
(
name
.
empty
())
return
""
;
else
return
name
+
": "
;
}
const
check_shapes
&
has
(
std
::
size_t
n
)
const
{
assert
(
shapes
!=
nullptr
);
if
(
shapes
->
size
()
!=
n
)
RTG_THROW
(
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
RTG_THROW
(
prefix
()
+
"Wrong number of arguments: expected "
+
std
::
to_string
(
n
)
+
" but given "
+
std
::
to_string
(
shapes
->
size
()));
return
*
this
;
}
...
...
@@ -30,7 +40,7 @@ struct check_shapes
if
(
!
shapes
->
empty
())
{
if
(
shapes
->
front
().
lens
().
size
()
!=
n
)
RTG_THROW
(
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
RTG_THROW
(
prefix
()
+
"Only "
+
std
::
to_string
(
n
)
+
"d supported"
);
}
return
*
this
;
}
...
...
@@ -38,28 +48,28 @@ struct check_shapes
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
RTG_THROW
(
"Shapes do not match"
);
RTG_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
RTG_THROW
(
"Types do not match"
);
RTG_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
();
}))
RTG_THROW
(
"Dimensions do not match"
);
RTG_THROW
(
prefix
()
+
"Dimensions do not match"
);
return
*
this
;
}
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
lens
().
size
();
}))
RTG_THROW
(
"Dimensions do not match"
);
RTG_THROW
(
prefix
()
+
"Dimensions do not match"
);
return
*
this
;
}
...
...
@@ -101,7 +111,7 @@ struct convolution
std
::
string
name
()
const
{
return
"convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
...
...
@@ -175,7 +185,7 @@ struct pooling
std
::
string
name
()
const
{
return
"pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
).
only_dims
(
4
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
only_dims
(
4
);
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
...
...
@@ -219,7 +229,7 @@ struct activation
std
::
string
name
()
const
{
return
"activation"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
.
front
();
}
...
...
@@ -237,7 +247,7 @@ struct transpose
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
...
...
@@ -269,7 +279,7 @@ struct contiguous
std
::
string
name
()
const
{
return
"contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
lens
=
inputs
.
at
(
0
).
lens
();
auto
t
=
inputs
.
at
(
0
).
type
();
if
(
lens
.
size
()
<
2
)
...
...
@@ -287,7 +297,7 @@ struct reshape
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
...
...
@@ -322,7 +332,7 @@ struct gemm
std
::
string
name
()
const
{
return
"gemm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
();
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
();
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
...
...
@@ -492,12 +502,29 @@ struct outline
std
::
string
name
()
const
{
return
"outline"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
0
);
check_shapes
{
inputs
,
*
this
}.
has
(
0
);
return
s
;
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
)
const
{
return
{
s
,
nullptr
};
}
};
template
<
class
T
>
struct
check_context
{
std
::
string
name
()
const
{
return
"check_context"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
return
{};
}
argument
compute
(
context
&
ctx
,
shape
,
std
::
vector
<
argument
>
)
const
{
T
*
x
=
any_cast
<
T
>
(
&
ctx
);
if
(
x
==
nullptr
)
RTG_THROW
(
std
::
string
(
"Unexpected context type: "
)
+
ctx
.
type_id
().
name
());
return
{};
}
};
}
// namespace rtg
#endif
src/include/rtg/target.hpp
View file @
8e0fff81
...
...
@@ -73,6 +73,14 @@ struct target
:
nullptr
;
}
const
std
::
type_info
&
type_id
()
const
{
if
(
private_detail_te_handle_empty
())
return
typeid
(
std
::
nullptr_t
);
else
return
private_detail_te_get_handle
().
type
();
}
std
::
string
name
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -150,13 +158,20 @@ struct target
}
};
bool
private_detail_te_handle_empty
()
const
{
return
private_detail_te_handle_mem_var
==
nullptr
;
}
const
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
const
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
return
*
private_detail_te_handle_mem_var
;
}
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
...
...
src/onnx/verify_onnx.cpp
View file @
8e0fff81
...
...
@@ -30,7 +30,7 @@ rtg::argument run_gpu(std::string file)
auto
handle
=
rtg
::
miopen
::
make_obj
<
rtg
::
miopen
::
miopen_handle
>
(
&
miopenCreate
);
auto
out
=
p
.
eval
(
{{
"Input3"
,
input3
},
{
"handle"
,
{
rtg
::
shape
::
any_type
,
handle
.
get
()}},
{
"output"
,
output
}});
{{
"Input3"
,
input3
},
{
"output"
,
output
}});
std
::
cout
<<
p
<<
std
::
endl
;
return
rtg
::
miopen
::
from_gpu
(
out
);
}
...
...
src/targets/miopen/include/rtg/miopen/miopen_target.hpp
View file @
8e0fff81
...
...
@@ -10,7 +10,7 @@ struct miopen_target
{
std
::
string
name
()
const
;
void
apply
(
program
&
p
)
const
;
context
get_context
()
const
{
return
{};
}
context
get_context
()
const
;
};
}
// namespace miopen
...
...
src/targets/miopen/miopen_target.cpp
View file @
8e0fff81
...
...
@@ -10,6 +10,11 @@
namespace
rtg
{
namespace
miopen
{
struct
miopen_context
{
shared
<
miopen_handle
>
handle
;
};
struct
miopen_convolution
{
convolution
op
;
...
...
@@ -18,46 +23,47 @@ struct miopen_convolution
std
::
string
name
()
const
{
return
"miopen::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
4
);
return
op
.
compute_shape
({
inputs
.
at
(
1
),
inputs
.
at
(
2
)});
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
2
].
get_shape
());
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
float
alpha
=
1
,
beta
=
0
;
int
algo_count
;
miopenConvAlgoPerf_t
perf
;
miopenFindConvolutionForwardAlgorithm
(
args
[
0
].
implici
t
(),
miopenFindConvolutionForwardAlgorithm
(
ctx
.
handle
.
ge
t
(),
x_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
args
[
2
].
implicit
(),
args
[
1
].
implicit
(),
cd
.
get
(),
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
1
,
&
algo_count
,
&
perf
,
nullptr
,
0
,
false
);
miopenConvolutionForward
(
args
[
0
].
implici
t
(),
miopenConvolutionForward
(
ctx
.
handle
.
ge
t
(),
&
alpha
,
x_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
args
[
2
].
implicit
(),
args
[
1
].
implicit
(),
cd
.
get
(),
perf
.
fwd_algo
,
&
beta
,
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
nullptr
,
0
);
return
args
[
3
];
return
args
[
2
];
}
};
...
...
@@ -69,29 +75,30 @@ struct miopen_pooling
std
::
string
name
()
const
{
return
"miopen::pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
float
alpha
=
1
,
beta
=
0
;
miopenPoolingForward
(
args
[
0
].
implici
t
(),
miopenPoolingForward
(
ctx
.
handle
.
ge
t
(),
pd
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
2
].
implicit
(),
args
[
1
].
implicit
(),
false
,
nullptr
,
0
);
return
args
[
2
];
return
args
[
1
];
}
};
...
...
@@ -100,17 +107,17 @@ struct miopen_add
std
::
string
name
()
const
{
return
"miopen::add"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
4
);
return
inputs
.
at
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
if
(
args
[
2
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
argument
result
{
output_shape
};
visit_all
(
result
,
from_gpu
(
args
[
1
]),
from_gpu
(
args
[
2
]))(
visit_all
(
result
,
from_gpu
(
args
[
0
]),
from_gpu
(
args
[
1
]))(
[
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
...
...
@@ -121,22 +128,23 @@ struct miopen_add
}
else
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
a_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
2
].
get_shape
());
auto
a_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
b_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
c_desc
=
make_tensor
(
output_shape
);
miopenOpTensor
(
args
[
0
].
implici
t
(),
miopenOpTensor
(
ctx
.
handle
.
ge
t
(),
miopenTensorOpAdd
,
&
alpha
,
a_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
0
].
implicit
(),
&
alpha
,
b_desc
.
get
(),
args
[
2
].
implicit
(),
args
[
1
].
implicit
(),
&
beta
,
c_desc
.
get
(),
args
[
3
].
implicit
());
return
args
[
3
];
args
[
2
].
implicit
());
return
args
[
2
];
}
}
};
...
...
@@ -147,14 +155,14 @@ struct miopen_gemm
std
::
string
name
()
const
{
return
"miopen::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
4
);
return
op
.
compute_shape
({
inputs
.
at
(
1
),
inputs
.
at
(
2
)});
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
from_gpu
(
args
[
1
]),
from_gpu
(
args
[
2
]))(
visit_all
(
result
,
from_gpu
(
args
[
0
]),
from_gpu
(
args
[
1
]))(
[
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
dfor
(
input1
.
get_shape
().
lens
()[
0
],
input2
.
get_shape
().
lens
()[
1
],
...
...
@@ -171,36 +179,36 @@ struct miopen_relu
std
::
string
name
()
const
{
return
"miopen::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
inputs
.
at
(
1
);
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
args
[
0
].
implici
t
(),
miopenActivationForward
(
ctx
.
handle
.
ge
t
(),
ad
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
1
].
implicit
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
2
].
implicit
());
args
[
1
].
implicit
());
return
args
[
2
];
return
args
[
1
];
}
};
struct
miopen_apply
{
program
*
prog
=
nullptr
;
instruction_ref
handle
{};
void
apply
()
{
handle
=
prog
->
add_parameter
(
"handle"
,
shape
{
shape
::
any_type
});
prog
->
insert_instruction
(
prog
->
begin
(),
check_context
<
miopen_context
>
{
});
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
{
if
(
it
->
op
.
name
()
==
"convolution"
)
...
...
@@ -248,7 +256,6 @@ struct miopen_apply
prog
->
replace_instruction
(
ins
,
miopen_convolution
{
op
,
std
::
move
(
cd
)},
handle
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
...
...
@@ -261,7 +268,7 @@ struct miopen_apply
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
handle
,
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
arguments
.
at
(
0
),
output
);
}
void
apply_activation
(
instruction_ref
ins
)
...
...
@@ -272,7 +279,7 @@ struct miopen_apply
{
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_relu
{
std
::
move
(
ad
)},
handle
,
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
arguments
.
at
(
0
),
output
);
}
}
...
...
@@ -280,7 +287,7 @@ struct miopen_apply
{
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_add
{},
handle
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_add
{},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
}
void
apply_gemm
(
instruction_ref
ins
)
...
...
@@ -288,7 +295,7 @@ struct miopen_apply
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_gemm
{
op
},
handle
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
}
};
...
...
@@ -296,6 +303,11 @@ std::string miopen_target::name() const { return "miopen"; }
void
miopen_target
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
context
miopen_target
::
get_context
()
const
{
return
miopen_context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
))};
}
}
// namespace miopen
}
// namespace rtg
test/miopen/miopen.cpp
View file @
8e0fff81
...
...
@@ -36,8 +36,6 @@ rtg::argument run_gpu()
}
m
[
"output"
]
=
rtg
::
miopen
::
to_gpu
(
rtg
::
generate_argument
(
p
.
get_parameter_shape
(
"output"
)));
auto
handle
=
rtg
::
miopen
::
make_obj
<
rtg
::
miopen
::
miopen_handle
>
(
&
miopenCreate
);
m
[
"handle"
]
=
{
rtg
::
shape
::
any_type
,
handle
.
get
()};
return
rtg
::
miopen
::
from_gpu
(
p
.
eval
(
m
));
}
...
...
tools/te.py
View file @
8e0fff81
...
...
@@ -63,6 +63,12 @@ struct ${struct_name}
nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty()) return typeid(std::nullptr_t);
else return private_detail_te_get_handle().type();
}
${nonvirtual_members}
private:
...
...
@@ -118,11 +124,20 @@ private:
{}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type & private_detail_te_get_handle () const
{ return *private_detail_te_handle_mem_var; }
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment