Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9a3fc32d
"docs/en_US/vscode:/vscode.git/clone" did not exist on "0fd38debd5c6567e22bd38d2752d7a34306d7c5f"
Commit
9a3fc32d
authored
Aug 18, 2018
by
Paul
Browse files
Use const refs where possible
parent
61991b42
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
41 deletions
+43
-41
src/targets/gpu/include/migraph/gpu/hip.hpp
src/targets/gpu/include/migraph/gpu/hip.hpp
+6
-5
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+18
-17
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+6
-6
test/include/test.hpp
test/include/test.hpp
+1
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+3
-3
test/operation.cpp
test/operation.cpp
+4
-4
tools/include/operation.hpp
tools/include/operation.hpp
+5
-5
No files found.
src/targets/gpu/include/migraph/gpu/hip.hpp
View file @
9a3fc32d
...
@@ -2,11 +2,12 @@
...
@@ -2,11 +2,12 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraph/operators.hpp>
#include <migraph/operators.hpp>
#include <utility>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
migraph
::
argument
allocate_gpu
(
migraph
::
shape
s
,
bool
host
=
false
);
migraph
::
argument
allocate_gpu
(
const
migraph
::
shape
&
s
,
bool
host
=
false
);
migraph
::
argument
to_gpu
(
migraph
::
argument
arg
,
bool
host
=
false
);
migraph
::
argument
to_gpu
(
migraph
::
argument
arg
,
bool
host
=
false
);
...
@@ -16,12 +17,12 @@ struct hip_allocate
...
@@ -16,12 +17,12 @@ struct hip_allocate
{
{
std
::
string
tag
{};
std
::
string
tag
{};
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
)
const
{
{
return
allocate_gpu
(
output_shape
);
return
allocate_gpu
(
output_shape
);
}
}
...
@@ -30,12 +31,12 @@ struct hip_allocate
...
@@ -30,12 +31,12 @@ struct hip_allocate
struct
hip_write
struct
hip_write
{
{
std
::
string
name
()
const
{
return
"hip::write"
;
}
std
::
string
name
()
const
{
return
"hip::write"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
return
to_gpu
(
args
.
front
());
return
to_gpu
(
args
.
front
());
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
9a3fc32d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <migraph/iterator_for.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/context.hpp>
#include <utility>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference
...
@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
6
);
check_shapes
{
inputs
,
*
this
}.
has
(
6
);
return
op
.
compute_shape
(
return
op
.
compute_shape
(
{
inputs
.
at
(
0
),
inputs
.
at
(
1
),
inputs
.
at
(
2
),
inputs
.
at
(
3
),
inputs
.
at
(
4
)});
{
inputs
.
at
(
0
),
inputs
.
at
(
1
),
inputs
.
at
(
2
),
inputs
.
at
(
3
),
inputs
.
at
(
4
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -63,12 +64,12 @@ struct miopen_convolution
...
@@ -63,12 +64,12 @@ struct miopen_convolution
miopenConvFwdAlgorithm_t
algo
{};
miopenConvFwdAlgorithm_t
algo
{};
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
@@ -91,7 +92,7 @@ struct miopen_convolution
...
@@ -91,7 +92,7 @@ struct miopen_convolution
return
args
[
3
];
return
args
[
3
];
}
}
shape
compile
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
)
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
)
{
{
shape
workspace_shape
{};
shape
workspace_shape
{};
auto
x_desc
=
make_tensor
(
inputs
[
0
]
->
get_shape
());
auto
x_desc
=
make_tensor
(
inputs
[
0
]
->
get_shape
());
...
@@ -136,12 +137,12 @@ struct miopen_pooling
...
@@ -136,12 +137,12 @@ struct miopen_pooling
shared
<
pooling_descriptor
>
pd
;
shared
<
pooling_descriptor
>
pd
;
std
::
string
name
()
const
{
return
"gpu::pooling"
;
}
std
::
string
name
()
const
{
return
"gpu::pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -167,13 +168,13 @@ struct miopen_pooling
...
@@ -167,13 +168,13 @@ struct miopen_pooling
struct
miopen_add
struct
miopen_add
{
{
std
::
string
name
()
const
{
return
"gpu::add"
;
}
std
::
string
name
()
const
{
return
"gpu::add"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
not_broadcasted
();
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
not_broadcasted
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
{
...
@@ -214,12 +215,12 @@ struct miopen_gemm
...
@@ -214,12 +215,12 @@ struct miopen_gemm
{
{
gemm
op
;
gemm
op
;
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
...
@@ -253,14 +254,14 @@ struct miopen_contiguous
...
@@ -253,14 +254,14 @@ struct miopen_contiguous
{
{
contiguous
op
;
contiguous
op
;
std
::
string
name
()
const
{
return
"gpu::contiguous"
;
}
std
::
string
name
()
const
{
return
"gpu::contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
hip_contiguous
(
output_shape
,
args
.
at
(
0
),
args
.
at
(
1
));
hip_contiguous
(
std
::
move
(
output_shape
)
,
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
1
);
return
args
.
at
(
1
);
}
}
};
};
...
@@ -269,13 +270,13 @@ struct miopen_relu
...
@@ -269,13 +270,13 @@ struct miopen_relu
{
{
shared
<
activation_descriptor
>
ad
;
shared
<
activation_descriptor
>
ad
;
std
::
string
name
()
const
{
return
"gpu::relu"
;
}
std
::
string
name
()
const
{
return
"gpu::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
...
@@ -350,7 +351,7 @@ struct miopen_apply
...
@@ -350,7 +351,7 @@ struct miopen_apply
else
else
{
{
auto
is
=
prog
->
add_outline
(
s
);
auto
is
=
prog
->
add_outline
(
s
);
auto
result
=
prog
->
insert_instruction
(
ins
,
hip_allocate
{
tag
},
is
);
auto
result
=
prog
->
insert_instruction
(
ins
,
hip_allocate
{
std
::
move
(
tag
)
},
is
);
return
result
;
return
result
;
}
}
}
}
...
...
test/include/basic_ops.hpp
View file @
9a3fc32d
...
@@ -6,7 +6,7 @@ struct sum_op
...
@@ -6,7 +6,7 @@ struct sum_op
{
{
std
::
string
name
()
const
{
return
"sum"
;
}
std
::
string
name
()
const
{
return
"sum"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
migraph
::
argument
result
;
migraph
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -36,7 +36,7 @@ struct minus_op
...
@@ -36,7 +36,7 @@ struct minus_op
{
{
std
::
string
name
()
const
{
return
"minus"
;
}
std
::
string
name
()
const
{
return
"minus"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
migraph
::
argument
result
;
migraph
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -66,7 +66,7 @@ struct pass_op
...
@@ -66,7 +66,7 @@ struct pass_op
{
{
std
::
string
name
()
const
{
return
"pass"
;
}
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
if
(
args
.
empty
())
if
(
args
.
empty
())
return
{};
return
{};
...
@@ -85,7 +85,7 @@ struct pass_standard_op
...
@@ -85,7 +85,7 @@ struct pass_standard_op
{
{
std
::
string
name
()
const
{
return
"pass"
;
}
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
if
(
args
.
empty
())
if
(
args
.
empty
())
return
{};
return
{};
...
@@ -109,12 +109,12 @@ struct nop
...
@@ -109,12 +109,12 @@ struct nop
{
{
std
::
string
name
()
const
{
return
"nop"
;
}
std
::
string
name
()
const
{
return
"nop"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
return
{};
return
{};
}
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
{
return
{};
}
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
return
{};
}
};
};
inline
migraph
::
literal
get_2x2
()
inline
migraph
::
literal
get_2x2
()
...
...
test/include/test.hpp
View file @
9a3fc32d
...
@@ -141,7 +141,7 @@ bool throws(F f)
...
@@ -141,7 +141,7 @@ bool throws(F f)
}
}
template
<
class
F
,
class
Exception
>
template
<
class
F
,
class
Exception
>
bool
throws
(
F
f
,
std
::
string
msg
=
""
)
bool
throws
(
F
f
,
const
std
::
string
&
msg
=
""
)
{
{
try
try
{
{
...
...
test/op_shape_test.cpp
View file @
9a3fc32d
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "test.hpp"
#include "test.hpp"
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
expect_shape
(
migraph
::
shape
expected
,
migraph
::
operation
op
,
Ts
...
xs
)
void
expect_shape
(
const
migraph
::
shape
&
expected
,
const
migraph
::
operation
&
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
...
@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
...
@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
operation
op
,
Ts
...
xs
)
void
throws_shape
(
const
migraph
::
operation
&
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
...
@@ -46,7 +46,7 @@ struct always_false : std::false_type
...
@@ -46,7 +46,7 @@ struct always_false : std::false_type
};
};
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
shape
,
Ts
...)
void
throws_shape
(
const
migraph
::
shape
&
,
Ts
...)
{
{
static_assert
(
always_false
<
Ts
...
>
{},
static_assert
(
always_false
<
Ts
...
>
{},
"An expected shape should not be passed to throws_shape function"
);
"An expected shape should not be passed to throws_shape function"
);
...
...
test/operation.cpp
View file @
9a3fc32d
...
@@ -8,12 +8,12 @@ struct simple_operation
...
@@ -8,12 +8,12 @@ struct simple_operation
{
{
int
data
=
1
;
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -27,12 +27,12 @@ struct simple_operation
...
@@ -27,12 +27,12 @@ struct simple_operation
struct
simple_operation_no_print
struct
simple_operation_no_print
{
{
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
...
tools/include/operation.hpp
View file @
9a3fc32d
...
@@ -25,7 +25,7 @@ struct operation
...
@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
/// exception.
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
input
)
const
;
/**
/**
* @brief This performs the operation's computation
* @brief This performs the operation's computation
*
*
...
@@ -37,7 +37,7 @@ struct operation
...
@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
* the same the `output` shape.
*/
*/
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>
&
input
)
const
;
/// An optional stream operator to print the operation. When this is not
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
/// implemented, it will just print the operation's name.
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
...
@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
...
@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
...
@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
<%
<%
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
const
std
::
vector
<
shape
>
&
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>
&
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment