Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9a3fc32d
Commit
9a3fc32d
authored
Aug 18, 2018
by
Paul
Browse files
Use const refs where possible
parent
61991b42
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
41 deletions
+43
-41
src/targets/gpu/include/migraph/gpu/hip.hpp
src/targets/gpu/include/migraph/gpu/hip.hpp
+6
-5
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+18
-17
test/include/basic_ops.hpp
test/include/basic_ops.hpp
+6
-6
test/include/test.hpp
test/include/test.hpp
+1
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+3
-3
test/operation.cpp
test/operation.cpp
+4
-4
tools/include/operation.hpp
tools/include/operation.hpp
+5
-5
No files found.
src/targets/gpu/include/migraph/gpu/hip.hpp
View file @
9a3fc32d
...
@@ -2,11 +2,12 @@
...
@@ -2,11 +2,12 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraph/operators.hpp>
#include <migraph/operators.hpp>
#include <utility>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
migraph
::
argument
allocate_gpu
(
migraph
::
shape
s
,
bool
host
=
false
);
migraph
::
argument
allocate_gpu
(
const
migraph
::
shape
&
s
,
bool
host
=
false
);
migraph
::
argument
to_gpu
(
migraph
::
argument
arg
,
bool
host
=
false
);
migraph
::
argument
to_gpu
(
migraph
::
argument
arg
,
bool
host
=
false
);
...
@@ -16,12 +17,12 @@ struct hip_allocate
...
@@ -16,12 +17,12 @@ struct hip_allocate
{
{
std
::
string
tag
{};
std
::
string
tag
{};
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
std
::
string
name
()
const
{
return
"hip::allocate"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
)
const
{
{
return
allocate_gpu
(
output_shape
);
return
allocate_gpu
(
output_shape
);
}
}
...
@@ -30,12 +31,12 @@ struct hip_allocate
...
@@ -30,12 +31,12 @@ struct hip_allocate
struct
hip_write
struct
hip_write
{
{
std
::
string
name
()
const
{
return
"hip::write"
;
}
std
::
string
name
()
const
{
return
"hip::write"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
context
&
,
shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
return
to_gpu
(
args
.
front
());
return
to_gpu
(
args
.
front
());
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
9a3fc32d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <migraph/iterator_for.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/context.hpp>
#include <utility>
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
...
@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference
...
@@ -22,14 +23,14 @@ struct miopen_batch_norm_inference
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
6
);
check_shapes
{
inputs
,
*
this
}.
has
(
6
);
return
op
.
compute_shape
(
return
op
.
compute_shape
(
{
inputs
.
at
(
0
),
inputs
.
at
(
1
),
inputs
.
at
(
2
),
inputs
.
at
(
3
),
inputs
.
at
(
4
)});
{
inputs
.
at
(
0
),
inputs
.
at
(
1
),
inputs
.
at
(
2
),
inputs
.
at
(
3
),
inputs
.
at
(
4
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -63,12 +64,12 @@ struct miopen_convolution
...
@@ -63,12 +64,12 @@ struct miopen_convolution
miopenConvFwdAlgorithm_t
algo
{};
miopenConvFwdAlgorithm_t
algo
{};
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
());
...
@@ -91,7 +92,7 @@ struct miopen_convolution
...
@@ -91,7 +92,7 @@ struct miopen_convolution
return
args
[
3
];
return
args
[
3
];
}
}
shape
compile
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
)
shape
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
instruction_ref
>
inputs
)
{
{
shape
workspace_shape
{};
shape
workspace_shape
{};
auto
x_desc
=
make_tensor
(
inputs
[
0
]
->
get_shape
());
auto
x_desc
=
make_tensor
(
inputs
[
0
]
->
get_shape
());
...
@@ -136,12 +137,12 @@ struct miopen_pooling
...
@@ -136,12 +137,12 @@ struct miopen_pooling
shared
<
pooling_descriptor
>
pd
;
shared
<
pooling_descriptor
>
pd
;
std
::
string
name
()
const
{
return
"gpu::pooling"
;
}
std
::
string
name
()
const
{
return
"gpu::pooling"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
auto
y_desc
=
make_tensor
(
output_shape
);
...
@@ -167,13 +168,13 @@ struct miopen_pooling
...
@@ -167,13 +168,13 @@ struct miopen_pooling
struct
miopen_add
struct
miopen_add
{
{
std
::
string
name
()
const
{
return
"gpu::add"
;
}
std
::
string
name
()
const
{
return
"gpu::add"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
not_broadcasted
();
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
not_broadcasted
();
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
{
...
@@ -214,12 +215,12 @@ struct miopen_gemm
...
@@ -214,12 +215,12 @@ struct miopen_gemm
{
{
gemm
op
;
gemm
op
;
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
std
::
string
name
()
const
{
return
"gpu::convolution"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
...
@@ -253,14 +254,14 @@ struct miopen_contiguous
...
@@ -253,14 +254,14 @@ struct miopen_contiguous
{
{
contiguous
op
;
contiguous
op
;
std
::
string
name
()
const
{
return
"gpu::contiguous"
;
}
std
::
string
name
()
const
{
return
"gpu::contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
hip_contiguous
(
output_shape
,
args
.
at
(
0
),
args
.
at
(
1
));
hip_contiguous
(
std
::
move
(
output_shape
)
,
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
1
);
return
args
.
at
(
1
);
}
}
};
};
...
@@ -269,13 +270,13 @@ struct miopen_relu
...
@@ -269,13 +270,13 @@ struct miopen_relu
{
{
shared
<
activation_descriptor
>
ad
;
shared
<
activation_descriptor
>
ad
;
std
::
string
name
()
const
{
return
"gpu::relu"
;
}
std
::
string
name
()
const
{
return
"gpu::relu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
args
)
const
{
{
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
...
@@ -350,7 +351,7 @@ struct miopen_apply
...
@@ -350,7 +351,7 @@ struct miopen_apply
else
else
{
{
auto
is
=
prog
->
add_outline
(
s
);
auto
is
=
prog
->
add_outline
(
s
);
auto
result
=
prog
->
insert_instruction
(
ins
,
hip_allocate
{
tag
},
is
);
auto
result
=
prog
->
insert_instruction
(
ins
,
hip_allocate
{
std
::
move
(
tag
)
},
is
);
return
result
;
return
result
;
}
}
}
}
...
...
test/include/basic_ops.hpp
View file @
9a3fc32d
...
@@ -6,7 +6,7 @@ struct sum_op
...
@@ -6,7 +6,7 @@ struct sum_op
{
{
std
::
string
name
()
const
{
return
"sum"
;
}
std
::
string
name
()
const
{
return
"sum"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
migraph
::
argument
result
;
migraph
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -36,7 +36,7 @@ struct minus_op
...
@@ -36,7 +36,7 @@ struct minus_op
{
{
std
::
string
name
()
const
{
return
"minus"
;
}
std
::
string
name
()
const
{
return
"minus"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
migraph
::
argument
result
;
migraph
::
argument
result
;
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
...
@@ -66,7 +66,7 @@ struct pass_op
...
@@ -66,7 +66,7 @@ struct pass_op
{
{
std
::
string
name
()
const
{
return
"pass"
;
}
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
if
(
args
.
empty
())
if
(
args
.
empty
())
return
{};
return
{};
...
@@ -85,7 +85,7 @@ struct pass_standard_op
...
@@ -85,7 +85,7 @@ struct pass_standard_op
{
{
std
::
string
name
()
const
{
return
"pass"
;
}
std
::
string
name
()
const
{
return
"pass"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
args
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
std
::
vector
<
migraph
::
argument
>
args
)
const
{
{
if
(
args
.
empty
())
if
(
args
.
empty
())
return
{};
return
{};
...
@@ -109,12 +109,12 @@ struct nop
...
@@ -109,12 +109,12 @@ struct nop
{
{
std
::
string
name
()
const
{
return
"nop"
;
}
std
::
string
name
()
const
{
return
"nop"
;
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
return
{};
return
{};
}
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
{
return
{};
}
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
return
{};
}
};
};
inline
migraph
::
literal
get_2x2
()
inline
migraph
::
literal
get_2x2
()
...
...
test/include/test.hpp
View file @
9a3fc32d
...
@@ -141,7 +141,7 @@ bool throws(F f)
...
@@ -141,7 +141,7 @@ bool throws(F f)
}
}
template
<
class
F
,
class
Exception
>
template
<
class
F
,
class
Exception
>
bool
throws
(
F
f
,
std
::
string
msg
=
""
)
bool
throws
(
F
f
,
const
std
::
string
&
msg
=
""
)
{
{
try
try
{
{
...
...
test/op_shape_test.cpp
View file @
9a3fc32d
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "test.hpp"
#include "test.hpp"
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
expect_shape
(
migraph
::
shape
expected
,
migraph
::
operation
op
,
Ts
...
xs
)
void
expect_shape
(
const
migraph
::
shape
&
expected
,
const
migraph
::
operation
&
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
...
@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
...
@@ -24,7 +24,7 @@ void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
operation
op
,
Ts
...
xs
)
void
throws_shape
(
const
migraph
::
operation
&
op
,
Ts
...
xs
)
{
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
std
::
vector
<
migraph
::
shape
>
shapes
{
xs
...};
...
@@ -46,7 +46,7 @@ struct always_false : std::false_type
...
@@ -46,7 +46,7 @@ struct always_false : std::false_type
};
};
template
<
class
...
Ts
>
template
<
class
...
Ts
>
void
throws_shape
(
migraph
::
shape
,
Ts
...)
void
throws_shape
(
const
migraph
::
shape
&
,
Ts
...)
{
{
static_assert
(
always_false
<
Ts
...
>
{},
static_assert
(
always_false
<
Ts
...
>
{},
"An expected shape should not be passed to throws_shape function"
);
"An expected shape should not be passed to throws_shape function"
);
...
...
test/operation.cpp
View file @
9a3fc32d
...
@@ -8,12 +8,12 @@ struct simple_operation
...
@@ -8,12 +8,12 @@ struct simple_operation
{
{
int
data
=
1
;
int
data
=
1
;
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
@@ -27,12 +27,12 @@ struct simple_operation
...
@@ -27,12 +27,12 @@ struct simple_operation
struct
simple_operation_no_print
struct
simple_operation_no_print
{
{
std
::
string
name
()
const
{
return
"simple"
;
}
std
::
string
name
()
const
{
return
"simple"
;
}
migraph
::
shape
compute_shape
(
std
::
vector
<
migraph
::
shape
>
)
const
migraph
::
shape
compute_shape
(
const
std
::
vector
<
migraph
::
shape
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
migraph
::
argument
migraph
::
argument
compute
(
migraph
::
context
&
,
migraph
::
shape
,
std
::
vector
<
migraph
::
argument
>
)
const
compute
(
migraph
::
context
&
,
const
migraph
::
shape
&
,
const
std
::
vector
<
migraph
::
argument
>
&
)
const
{
{
MIGRAPH_THROW
(
"not computable"
);
MIGRAPH_THROW
(
"not computable"
);
}
}
...
...
tools/include/operation.hpp
View file @
9a3fc32d
...
@@ -25,7 +25,7 @@ struct operation
...
@@ -25,7 +25,7 @@ struct operation
/// This is used to compute the resulting shape from an operation. If an
/// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an
/// operation cannot be run with input shapes, then it should throw an
/// exception.
/// exception.
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
input
)
const
;
/**
/**
* @brief This performs the operation's computation
* @brief This performs the operation's computation
*
*
...
@@ -37,7 +37,7 @@ struct operation
...
@@ -37,7 +37,7 @@ struct operation
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
* the same the `output` shape.
*/
*/
argument
compute
(
context
&
ctx
,
shape
output
,
std
::
vector
<
argument
>
input
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>
&
input
)
const
;
/// An optional stream operator to print the operation. When this is not
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
/// implemented, it will just print the operation's name.
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
operation
&
op
);
...
@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -56,7 +56,7 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
template
<
class
T
>
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
shape
output_shape
,
std
::
vector
<
argument
>
input
)
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>
&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
}
...
@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
...
@@ -64,8 +64,8 @@ argument compute_op(const T& x, context& ctx, shape output_shape, std::vector<ar
<%
<%
interface
(
'
operation
'
,
interface
(
'
operation
'
,
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
name
'
,
returns
=
'
std
::
string
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
std
::
vector
<
shape
>
'
,
const
=
True
),
virtual
(
'
compute_shape
'
,
returns
=
'
shape
'
,
input
=
'
const
std
::
vector
<
shape
>
&
'
,
const
=
True
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
shape
'
,
input
=
'
std
::
vector
<
argument
>
'
,
const
=
True
,
default
=
'
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>
&
'
,
const
=
True
,
default
=
'
compute_op
'
),
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
friend
(
'
operator
<<
'
,
returns
=
'
std
::
ostream
&
'
,
os
=
'
std
::
ostream
&
'
,
op
=
'
const
operation
&
'
,
using
=
'
migraph
::
operation_stream
::
operator
<<
'
)
)
)
%>
%>
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment