Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
99ee76c0
Commit
99ee76c0
authored
Sep 24, 2018
by
Paul
Browse files
Merge branch 'master' into mem_color_separate_literal-master
parents
85c2c29d
f9f4f713
Changes
40
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
821 additions
and
156 deletions
+821
-156
src/include/migraph/verify.hpp
src/include/migraph/verify.hpp
+1
-1
src/include/migraph/verify_args.hpp
src/include/migraph/verify_args.hpp
+36
-2
src/instruction.cpp
src/instruction.cpp
+166
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+21
-15
src/onnx/verify_onnx.cpp
src/onnx/verify_onnx.cpp
+43
-11
src/program.cpp
src/program.cpp
+59
-47
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+8
-8
src/targets/cpu/cpu_lowering.cpp
src/targets/cpu/cpu_lowering.cpp
+70
-12
src/targets/gpu/eliminate_workspace.cpp
src/targets/gpu/eliminate_workspace.cpp
+3
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+5
-5
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+80
-33
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+2
-2
src/targets/gpu/write_literals.cpp
src/targets/gpu/write_literals.cpp
+2
-2
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+160
-0
test/eval_test.cpp
test/eval_test.cpp
+4
-4
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+79
-7
test/include/rob.hpp
test/include/rob.hpp
+48
-0
test/operation.cpp
test/operation.cpp
+2
-0
test/validate.cpp
test/validate.cpp
+7
-1
tools/include/operation.hpp
tools/include/operation.hpp
+25
-3
No files found.
src/include/migraph/verify.hpp
View file @
99ee76c0
...
@@ -140,7 +140,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
...
@@ -140,7 +140,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
{
{
return
mismatch_idx
(
r1
,
r2
,
[
&
](
auto
x
,
auto
y
)
{
return
mismatch_idx
(
r1
,
r2
,
[
&
](
auto
x
,
auto
y
)
{
auto
d
=
abs_diff
(
x
,
y
);
auto
d
=
abs_diff
(
x
,
y
);
return
!
(
d
>
diff
&&
d
<
diff
);
return
float_equal
(
d
,
diff
);
});
});
}
}
...
...
src/include/migraph/verify_args.hpp
View file @
99ee76c0
...
@@ -6,14 +6,16 @@
...
@@ -6,14 +6,16 @@
namespace
migraph
{
namespace
migraph
{
inline
void
verify_args
(
const
std
::
string
&
name
,
inline
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
cpu_arg
,
const
argument
&
cpu_arg
,
const
argument
&
gpu_arg
,
const
argument
&
gpu_arg
,
double
tolerance
=
80
)
double
tolerance
=
80
)
{
{
bool
passed
=
true
;
visit_all
(
cpu_arg
,
gpu_arg
)([
&
](
auto
cpu
,
auto
gpu
)
{
visit_all
(
cpu_arg
,
gpu_arg
)([
&
](
auto
cpu
,
auto
gpu
)
{
double
error
;
double
error
;
if
(
not
verify_range
(
cpu
,
gpu
,
tolerance
,
&
error
))
passed
=
verify_range
(
cpu
,
gpu
,
tolerance
,
&
error
);
if
(
not
passed
)
{
{
// TODO: Check for nans
// TODO: Check for nans
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
std
::
cout
<<
"FAILED: "
<<
name
<<
std
::
endl
;
...
@@ -27,6 +29,9 @@ inline void verify_args(const std::string& name,
...
@@ -27,6 +29,9 @@ inline void verify_args(const std::string& name,
if
(
range_zero
(
gpu
))
if
(
range_zero
(
gpu
))
std
::
cout
<<
"Gpu data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Gpu data is all zeros"
<<
std
::
endl
;
auto
mxdiff
=
max_diff
(
cpu
,
gpu
);
std
::
cout
<<
"Max diff: "
<<
mxdiff
<<
std
::
endl
;
auto
idx
=
mismatch_idx
(
cpu
,
gpu
,
float_equal
);
auto
idx
=
mismatch_idx
(
cpu
,
gpu
,
float_equal
);
if
(
idx
<
range_distance
(
cpu
))
if
(
idx
<
range_distance
(
cpu
))
{
{
...
@@ -45,7 +50,36 @@ inline void verify_args(const std::string& name,
...
@@ -45,7 +50,36 @@ inline void verify_args(const std::string& name,
<<
gpu
[
gpu_nan_idx
]
<<
std
::
endl
;
<<
gpu
[
gpu_nan_idx
]
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
}
}
else
{
if
(
range_zero
(
cpu
))
std
::
cout
<<
"Cpu data is all zeros"
<<
std
::
endl
;
if
(
range_zero
(
gpu
))
std
::
cout
<<
"Gpu data is all zeros"
<<
std
::
endl
;
// auto mxdiff = max_diff(cpu, gpu);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(cpu, gpu, float_equal);
// if(idx < range_distance(cpu))
// {
// std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
// << std::endl;
// }
auto
cpu_nan_idx
=
find_idx
(
cpu
,
not_finite
);
if
(
cpu_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in cpu at "
<<
cpu_nan_idx
<<
": "
<<
cpu
[
cpu_nan_idx
]
<<
std
::
endl
;
auto
gpu_nan_idx
=
find_idx
(
gpu
,
not_finite
);
if
(
gpu_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in gpu at "
<<
gpu_nan_idx
<<
": "
<<
gpu
[
gpu_nan_idx
]
<<
std
::
endl
;
// std::cout << std::endl;
}
});
});
return
passed
;
}
}
}
// namespace migraph
}
// namespace migraph
...
...
src/instruction.cpp
0 → 100644
View file @
99ee76c0
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
namespace
migraph
{
/// Construct an instruction from an operation, the shape it produces, and the
/// instructions it consumes. Each argument is taken by value and moved into
/// the corresponding member.
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)),
      result(std::move(r)),
      arguments(std::move(args))
{
    // Nothing else to do: output links are not established here.
}
/// Construct a literal instruction: the operation is the builtin literal
/// marker, the result shape is taken from the literal, and the literal itself
/// is moved into `lit`.
// NOTE(review): `result` reads `l.get_shape()` while `lit` moves from `l`;
// this presumably relies on `result` being declared before `lit` in the class
// (member declaration order is not visible here) - confirm.
instruction::instruction(literal l)
    : op(builtin::literal{}),
      result(l.get_shape()),
      lit(std::move(l))
{
}
/// Replace the result shape of this instruction with `r`.
///
/// When the shape actually changes, every instruction recorded in `output`
/// is asked to recompute its own shape, so the change ripples forward.
void instruction::replace(const shape& r)
{
    const bool shape_changed = (r != result);
    if(!shape_changed)
        return;

    result = r;
    for(const auto& consumer : output)
    {
        // Consumers are expected to be regular operations, not '@'-prefixed ones.
        assert(consumer->name().front() != '@');
        consumer->recompute_shape();
    }
}
void
instruction
::
recompute_shape
()
{
replace
(
compute_shape
(
op
,
arguments
));
}
void
instruction
::
clear_arguments
()
{
for
(
auto
&&
arg
:
arguments
)
{
arg
->
remove_output
(
*
this
);
}
arguments
.
clear
();
}
/// An instruction equals an instruction_ref when the ref points at that very
/// object (identity comparison by address, not by value).
bool operator==(const instruction& i, instruction_ref ref)
{
    const instruction* lhs = std::addressof(i);
    const instruction* rhs = std::addressof(*ref);
    return lhs == rhs;
}
/// Check this instruction's validity relative to a program that starts at `start`.
///
/// On top of valid(), every argument must list this instruction among its
/// outputs and must be positioned before it (measured as distance from `start`).
bool instruction::valid(instruction_ref start) const
{
    if(!valid())
        return false;

    const auto linked_and_earlier = [&](instruction_ref producer) {
        const auto& users = producer->outputs();
        const auto self   = std::find(users.begin(), users.end(), *this);
        if(self == users.end())
            return false;
        // The argument has to come before this instruction in the sequence.
        return std::distance(start, producer) < std::distance(start, *self);
    };
    return std::all_of(arguments.begin(), arguments.end(), linked_and_earlier);
}
/// Check the instruction's internal consistency.
///
/// The stored result shape must match the expected one: the literal's shape
/// for "@literal", the stored shape itself for "@param", and otherwise the
/// shape computed from the operation and arguments. If that computation
/// throws a migraph::exception, the instruction is invalid. Additionally,
/// every output must list this instruction among its inputs.
bool instruction::valid() const
{
    shape computed;
    if(op.name() == "@literal")
    {
        computed = lit.get_shape();
    }
    else if(op.name() == "@param")
    {
        // Parameters have no way to recompute a shape; accept the stored one.
        computed = result;
    }
    else
    {
        try
        {
            computed = compute_shape(op, arguments);
        }
        catch(migraph::exception&)
        {
            return false;
        }
    }

    if(!(result == computed))
        return false;

    return std::all_of(output.begin(), output.end(), [&](instruction_ref consumer) {
        const auto& feeds = consumer->inputs();
        return std::find(feeds.begin(), feeds.end(), *this) != feeds.end();
    });
}
/// Return a copy of the stored result shape.
shape instruction::get_shape() const
{
    return this->result;
}
/// Return the stored literal. Asserts that the operation is named "@literal".
const literal& instruction::get_literal() const
{
    assert(op.name() == "@literal");
    return this->lit;
}
/// Return a const reference to the operation held by this instruction.
const operation& instruction::get_operator() const
{
    return this->op;
}
std
::
string
instruction
::
name
()
const
{
return
op
.
name
();
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
inputs
()
const
{
return
arguments
;
}
const
std
::
vector
<
instruction_ref
>&
instruction
::
outputs
()
const
{
return
output
;
}
/// Symmetric form of instruction/ref equality; forwards to the
/// (instruction, ref) overload.
bool operator==(instruction_ref ref, const instruction& i)
{
    const bool same = (i == ref);
    return same;
}
/// Negation of operator==(const instruction&, instruction_ref).
bool operator!=(const instruction& i, instruction_ref ref)
{
    const bool same = (i == ref);
    return !same;
}
/// Symmetric form of instruction/ref inequality; expressed through
/// operator==(const instruction&, instruction_ref).
bool operator!=(instruction_ref ref, const instruction& i)
{
    const bool same = (i == ref);
    return !same;
}
/// Record `ins` as an output of this instruction, unless it is already listed
/// (the output list holds each instruction at most once).
void instruction::add_output(instruction_ref ins)
{
    const bool already_listed = std::find(output.begin(), output.end(), ins) != output.end();
    if(already_listed)
        return;
    output.push_back(ins);
}
template
<
class
T
>
void
instruction
::
remove_output
(
const
T
&
ins
)
{
migraph
::
erase
(
output
,
ins
);
}
/// Register `ref` as an output on each of its own inputs, so the
/// argument -> consumer links mirror the consumer -> argument links.
void instruction::backreference(instruction_ref ref)
{
    const auto& producers = ref->inputs();
    std::for_each(producers.begin(), producers.end(), [&](instruction_ref producer) {
        producer->add_output(ref);
    });
}
/// Swap argument `old` for `new_ins` on instruction `ins`, re-establish the
/// output links of its arguments, and recompute its shape.
void instruction::replace_argument(instruction_ref ins,
                                   instruction_ref old,
                                   instruction_ref new_ins)
{
    instruction& target = *ins;
    target.replace_argument(old, new_ins);
    // The new argument does not know about `ins` yet; register it.
    backreference(ins);
    target.recompute_shape();
}
/// Replace the operation, shape and arguments of instruction `ins`, then
/// register `ins` as an output of each of its new arguments.
void instruction::replace(instruction_ref ins,
                          operation o,
                          const shape& r,
                          std::vector<instruction_ref> args)
{
    instruction& target = *ins;
    target.replace(std::move(o), r, std::move(args));
    backreference(ins);
}
/// Replace this instruction's operation, result shape and arguments, in that
/// order.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    this->op = std::move(o);
    // Installs the new shape first; this may trigger shape recomputation on outputs.
    this->replace(r);
    // Finally detach from the old arguments and adopt the new ones.
    this->replace(std::move(args));
}
void
instruction
::
replace
(
std
::
vector
<
instruction_ref
>
args
)
{
clear_arguments
();
arguments
=
std
::
move
(
args
);
}
/// Substitute every occurrence of `old` in the argument list with `new_ins`,
/// then tell `old` that this instruction no longer consumes it.
void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
{
    for(auto& slot : arguments)
    {
        if(slot == old)
            slot = new_ins;
    }
    old->remove_output(*this);
}
/// Compute the output shape of `op` applied to `args`: gather each argument's
/// shape, in order, and hand the list to the operation's compute_shape.
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
    std::vector<shape> shapes;
    shapes.reserve(args.size());
    for(const auto& arg : args)
    {
        shapes.push_back(arg->get_shape());
    }
    return op.compute_shape(shapes);
}
}
// namespace migraph
src/onnx/onnx.cpp
View file @
99ee76c0
...
@@ -28,10 +28,6 @@ struct unknown
...
@@ -28,10 +28,6 @@ struct unknown
else
else
return
input
.
front
();
return
input
.
front
();
}
}
argument
compute
(
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPH_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
...
@@ -58,6 +54,7 @@ struct onnx_parser
...
@@ -58,6 +54,7 @@ struct onnx_parser
add_generic_op
(
"Mul"
,
mul
{});
add_generic_op
(
"Mul"
,
mul
{});
add_generic_op
(
"Relu"
,
activation
{
"relu"
});
add_generic_op
(
"Relu"
,
activation
{
"relu"
});
add_generic_op
(
"Sub"
,
sub
{});
add_generic_op
(
"Sub"
,
sub
{});
add_generic_op
(
"Sum"
,
add
{});
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
...
@@ -67,6 +64,7 @@ struct onnx_parser
...
@@ -67,6 +64,7 @@ struct onnx_parser
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -103,6 +101,15 @@ struct onnx_parser
...
@@ -103,6 +101,15 @@ struct onnx_parser
});
});
}
}
// Parse an ONNX Softmax node. The first input is reshaped to
// {dims[0], dims[1], 1, 1}, softmax is applied, and the result is reshaped
// back to {dims[0], dims[1]}. The node name and attributes are unused.
// NOTE(review): indexes dims[0] and dims[1] unchecked - assumes the input has
// at least two dimensions; confirm against callers.
instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    const auto input = args.front();
    const auto dims  = input->get_shape().lens();
    const long dim0  = long(dims[0]);
    const long dim1  = long(dims[1]);

    auto expanded   = prog.add_instruction(reshape{{dim0, dim1, 1, 1}}, input);
    auto normalized = prog.add_instruction(softmax{}, expanded);
    return prog.add_instruction(reshape{{dim0, dim1}}, normalized);
}
instruction_ref
instruction_ref
parse_conv
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_conv
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
@@ -160,7 +167,7 @@ struct onnx_parser
...
@@ -160,7 +167,7 @@ struct onnx_parser
}
}
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
{
{
literal
s
=
args
[
1
]
->
lit
;
literal
s
=
args
[
1
]
->
get_literal
()
;
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
args
[
0
]);
...
@@ -344,11 +351,10 @@ struct onnx_parser
...
@@ -344,11 +351,10 @@ struct onnx_parser
if
(
node
.
name
().
empty
())
if
(
node
.
name
().
empty
())
{
{
std
::
string
generated
=
"migraph_unnamed_node"
;
std
::
string
generated
=
"migraph_unnamed_node"
;
for
(
auto
&&
output
:
node
.
output
())
return
std
::
accumulate
(
node
.
output
().
begin
(),
{
node
.
output
().
end
(),
generated
+=
"_"
+
output
;
generated
,
}
[](
auto
x
,
auto
y
)
{
return
x
+
"_"
+
y
;
});
return
generated
;
}
}
return
node
.
name
();
return
node
.
name
();
}
}
...
@@ -481,11 +487,11 @@ struct onnx_parser
...
@@ -481,11 +487,11 @@ struct onnx_parser
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::tra
ns
f
or
m
auto
&&
tensor_dims
=
t
.
te
nsor
_type
().
shape
().
dim
();
for
(
auto
&&
d
:
t
.
te
nsor
_type
().
shape
().
dim
()
)
std
::
tra
ns
f
or
m
(
tensor_dims
.
begin
()
,
{
tensor_dims
.
end
(),
dims
.
push_back
(
d
.
dim_value
());
std
::
back_inserter
(
dims
),
}
[](
auto
&&
d
)
{
return
d
.
dim_value
();
});
return
{
shape_type
,
dims
};
return
{
shape_type
,
dims
};
}
}
};
};
...
...
src/onnx/verify_onnx.cpp
View file @
99ee76c0
...
@@ -51,48 +51,76 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
...
@@ -51,48 +51,76 @@ void verify_program(const std::string& name, F f, double tolerance = 100)
auto
x
=
run_cpu
(
f
);
auto
x
=
run_cpu
(
f
);
auto
y
=
run_gpu
(
f
);
auto
y
=
run_gpu
(
f
);
migraph
::
verify_args
(
name
,
x
,
y
,
tolerance
);
migraph
::
verify_args
(
name
,
x
,
y
,
tolerance
);
// std::cout << "cpu: " << x << std::endl;
// std::cout << "gpu: " << y << std::endl;
}
}
void
verify_instructions
(
const
migraph
::
program
&
prog
,
double
tolerance
=
80
)
void
verify_instructions
(
const
migraph
::
program
&
prog
,
double
tolerance
=
80
)
{
{
for
(
auto
&&
ins
:
prog
)
for
(
auto
&&
ins
:
prog
)
{
{
if
(
ins
.
op
.
name
().
front
()
==
'@'
)
if
(
ins
.
name
().
front
()
==
'@'
)
continue
;
continue
;
if
(
ins
.
op
.
name
()
==
"broadcast"
)
if
(
ins
.
name
()
==
"broadcast"
)
continue
;
continue
;
if
(
ins
.
op
.
name
()
==
"transpose"
)
if
(
ins
.
name
()
==
"transpose"
)
continue
;
continue
;
if
(
ins
.
op
.
name
()
==
"reshape"
)
if
(
ins
.
name
()
==
"reshape"
)
continue
;
continue
;
auto
create_program
=
[
&
]
{
auto
create_program
=
[
&
]
{
migraph
::
program
p
;
migraph
::
program
p
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
std
::
vector
<
migraph
::
instruction_ref
>
inputs
;
for
(
auto
&&
arg
:
ins
.
arguments
)
for
(
auto
&&
arg
:
ins
.
inputs
()
)
{
{
if
(
arg
->
op
.
name
()
==
"@literal"
)
if
(
arg
->
name
()
==
"@literal"
)
inputs
.
push_back
(
p
.
add_literal
(
arg
->
lit
));
inputs
.
push_back
(
p
.
add_literal
(
arg
->
get_literal
()
));
else
else
inputs
.
push_back
(
inputs
.
push_back
(
p
.
add_parameter
(
std
::
to_string
(
inputs
.
size
()),
arg
->
get_shape
()));
p
.
add_parameter
(
std
::
to_string
(
inputs
.
size
()),
arg
->
get_shape
()));
}
}
p
.
add_instruction
(
ins
.
op
,
inputs
);
p
.
add_instruction
(
ins
.
get_operator
()
,
inputs
);
return
p
;
return
p
;
};
};
try
try
{
{
std
::
cout
<<
"Verify: "
<<
ins
.
op
.
name
()
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
ins
.
name
()
<<
std
::
endl
;
std
::
cout
<<
create_program
()
<<
std
::
endl
;
std
::
cout
<<
create_program
()
<<
std
::
endl
;
verify_program
(
ins
.
op
.
name
(),
create_program
,
tolerance
);
verify_program
(
ins
.
name
(),
create_program
,
tolerance
);
}
}
catch
(...)
catch
(...)
{
{
std
::
cout
<<
"Instruction "
<<
ins
.
op
.
name
()
<<
" threw an exception."
<<
std
::
endl
;
std
::
cout
<<
"Instruction "
<<
ins
.
name
()
<<
" threw an exception."
<<
std
::
endl
;
throw
;
throw
;
}
}
}
}
}
}
template
<
class
F
>
void
verify_reduced
(
F
f
,
int
n
,
double
tolerance
=
80
)
{
auto
create_program
=
[
&
]
{
migraph
::
program
p
=
f
();
auto
last
=
std
::
prev
(
p
.
end
(),
n
+
1
);
p
.
remove_instructions
(
last
,
p
.
end
());
return
p
;
};
std
::
cout
<<
"Verify: "
<<
std
::
endl
;
std
::
cout
<<
create_program
()
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
create_program
,
tolerance
);
}
template
<
class
F
>
void
verify_reduced_program
(
F
f
,
double
tolerance
=
80
)
{
migraph
::
program
p
=
f
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
verify_reduced
(
f
,
i
,
tolerance
);
}
}
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
std
::
vector
<
std
::
string
>
args
(
argv
+
1
,
argv
+
argc
);
std
::
vector
<
std
::
string
>
args
(
argv
+
1
,
argv
+
argc
);
...
@@ -106,6 +134,10 @@ int main(int argc, char const* argv[])
...
@@ -106,6 +134,10 @@ int main(int argc, char const* argv[])
{
{
verify_instructions
(
p
);
verify_instructions
(
p
);
}
}
else
if
(
std
::
any_of
(
args
.
begin
(),
args
.
end
(),
[](
const
auto
&
s
)
{
return
s
==
"-r"
;
}))
{
verify_reduced_program
([
&
]
{
return
migraph
::
parse_onnx
(
file
);
});
}
else
else
{
{
verify_program
(
file
,
[
&
]
{
return
migraph
::
parse_onnx
(
file
);
});
verify_program
(
file
,
[
&
]
{
return
migraph
::
parse_onnx
(
file
);
});
...
...
src/program.cpp
View file @
99ee76c0
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
namespace
migraph
{
namespace
migraph
{
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_TRACE_COMPILE
)
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_TRACE_COMPILE
)
MIGRAPH_DECLARE_ENV_VAR
(
MIGRAPH_TRACE_EVAL
)
struct
program_impl
struct
program_impl
{
{
...
@@ -20,7 +21,7 @@ struct program_impl
...
@@ -20,7 +21,7 @@ struct program_impl
context
ctx
;
context
ctx
;
};
};
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
op
;
}
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
get_operator
()
;
}
template
<
class
F
>
template
<
class
F
>
static
void
print_program
(
std
::
ostream
&
os
,
const
program
&
p
,
F
annonate
)
static
void
print_program
(
std
::
ostream
&
os
,
const
program
&
p
,
F
annonate
)
...
@@ -31,27 +32,27 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -31,27 +32,27 @@ static void print_program(std::ostream& os, const program& p, F annonate)
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
if
(
ins
->
op
.
name
()
==
"@param"
)
if
(
ins
->
name
()
==
"@param"
)
{
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
op
).
parameter
;
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()
).
parameter
;
}
}
os
<<
var_name
<<
" = "
;
os
<<
var_name
<<
" = "
;
os
<<
ins
->
op
;
os
<<
ins
->
get_operator
()
;
if
(
ins
->
op
.
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
if
(
ins
->
lit
.
get_shape
().
elements
()
>
10
)
if
(
ins
->
get_literal
()
.
get_shape
().
elements
()
>
10
)
os
<<
"{ ... }"
;
os
<<
"{ ... }"
;
else
else
os
<<
"{"
<<
ins
->
lit
<<
"}"
;
os
<<
"{"
<<
ins
->
get_literal
()
<<
"}"
;
}
}
if
(
!
ins
->
arguments
.
empty
())
if
(
!
ins
->
inputs
()
.
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
->
arguments
)
for
(
auto
&&
arg
:
ins
->
inputs
()
)
{
{
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
assert
(
p
.
has_instruction
(
arg
)
&&
"Instruction not found"
);
os
<<
delim
<<
names
.
at
(
arg
);
os
<<
delim
<<
names
.
at
(
arg
);
...
@@ -60,7 +61,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
...
@@ -60,7 +61,7 @@ static void print_program(std::ostream& os, const program& p, F annonate)
os
<<
")"
;
os
<<
")"
;
}
}
os
<<
" -> "
<<
ins
->
result
;
os
<<
" -> "
<<
ins
->
get_shape
()
;
annonate
(
ins
,
names
);
annonate
(
ins
,
names
);
...
@@ -92,8 +93,8 @@ instruction_ref program::insert_instruction(instruction_ref ins,
...
@@ -92,8 +93,8 @@ instruction_ref program::insert_instruction(instruction_ref ins,
// TODO: Use move
// TODO: Use move
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
auto
result
=
impl
->
instructions
.
insert
(
ins
,
{
op
,
r
,
std
::
move
(
args
)});
backreference
(
result
);
instruction
::
backreference
(
result
);
// assert(result->
arguments
== args);
// assert(result->
inputs()
== args);
assert
(
result
->
valid
(
begin
()));
assert
(
result
->
valid
(
begin
()));
return
result
;
return
result
;
}
}
...
@@ -108,8 +109,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
...
@@ -108,8 +109,7 @@ instruction_ref program::replace_instruction(instruction_ref ins,
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
assert
(
not
starts_with
(
op
.
name
(),
"@"
));
shape
r
=
compute_shape
(
op
,
args
);
shape
r
=
compute_shape
(
op
,
args
);
ins
->
replace
(
op
,
r
,
std
::
move
(
args
));
instruction
::
replace
(
ins
,
op
,
r
,
std
::
move
(
args
));
backreference
(
ins
);
assert
(
ins
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
return
ins
;
return
ins
;
}
}
...
@@ -120,21 +120,21 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -120,21 +120,21 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert
(
has_instruction
(
rep
));
assert
(
has_instruction
(
rep
));
assert
(
ins
!=
rep
);
assert
(
ins
!=
rep
);
// TODO: Should it be an error if the output is empty?
// TODO: Should it be an error if the output is empty?
if
(
ins
->
output
.
empty
())
if
(
ins
->
output
s
()
.
empty
())
{
{
return
rep
;
return
rep
;
}
}
for
(
auto
&&
out
:
ins
->
output
)
for
(
auto
&&
out
:
ins
->
output
s
()
)
{
{
// TODO: Check for possible cycles
// TODO: Check for possible cycles
if
(
out
!=
rep
)
if
(
out
!=
rep
)
{
{
replace_argument
(
out
,
ins
,
rep
);
instruction
::
replace_argument
(
out
,
ins
,
rep
);
}
}
assert
(
out
->
valid
(
begin
()));
assert
(
out
->
valid
(
begin
()));
}
}
// Replacement should not be dead code unless its the last instruction
// Replacement should not be dead code unless its the last instruction
assert
(
!
rep
->
output
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
!
rep
->
output
s
()
.
empty
()
or
rep
==
std
::
prev
(
end
()));
assert
(
ins
->
valid
(
begin
()));
assert
(
ins
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
assert
(
rep
->
valid
(
begin
()));
return
rep
;
return
rep
;
...
@@ -143,7 +143,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
...
@@ -143,7 +143,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
instruction_ref
program
::
remove_instruction
(
instruction_ref
ins
)
{
{
assert
(
has_instruction
(
ins
));
assert
(
has_instruction
(
ins
));
assert
(
ins
->
output
.
empty
());
assert
(
ins
->
output
s
()
.
empty
());
ins
->
clear_arguments
();
ins
->
clear_arguments
();
return
impl
->
instructions
.
erase
(
ins
);
return
impl
->
instructions
.
erase
(
ins
);
}
}
...
@@ -155,7 +155,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_
...
@@ -155,7 +155,7 @@ instruction_ref program::remove_instructions(instruction_ref first, instruction_
// TODO: Check every element
// TODO: Check every element
assert
(
has_instruction
(
first
));
assert
(
has_instruction
(
first
));
std
::
for_each
(
first
,
last
,
[
&
](
instruction
&
ins
)
{
ins
.
clear_arguments
();
});
std
::
for_each
(
first
,
last
,
[
&
](
instruction
&
ins
)
{
ins
.
clear_arguments
();
});
assert
(
std
::
all_of
(
first
,
last
,
[
&
](
instruction
&
ins
)
{
return
ins
.
output
.
empty
();
}));
assert
(
std
::
all_of
(
first
,
last
,
[
&
](
instruction
&
ins
)
{
return
ins
.
output
s
()
.
empty
();
}));
return
impl
->
instructions
.
erase
(
first
,
last
);
return
impl
->
instructions
.
erase
(
first
,
last
);
}
}
...
@@ -188,9 +188,9 @@ shape program::get_parameter_shape(std::string name) const
...
@@ -188,9 +188,9 @@ shape program::get_parameter_shape(std::string name) const
{
{
auto
ins
=
std
::
find_if
(
auto
ins
=
std
::
find_if
(
impl
->
instructions
.
begin
(),
impl
->
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
impl
->
instructions
.
begin
(),
impl
->
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
if
(
x
.
op
.
name
()
==
"@param"
)
if
(
x
.
name
()
==
"@param"
)
{
{
return
any_cast
<
builtin
::
param
>
(
x
.
op
).
parameter
==
name
;
return
any_cast
<
builtin
::
param
>
(
x
.
get_operator
()
).
parameter
==
name
;
}
}
else
else
{
{
...
@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
...
@@ -198,7 +198,7 @@ shape program::get_parameter_shape(std::string name) const
}
}
});
});
if
(
ins
!=
this
->
end
())
if
(
ins
!=
this
->
end
())
return
ins
->
result
;
return
ins
->
get_shape
()
;
else
else
return
{};
return
{};
}
}
...
@@ -227,10 +227,10 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
...
@@ -227,10 +227,10 @@ std::unordered_map<std::string, shape> program::get_parameter_shapes() const
std
::
unordered_map
<
std
::
string
,
shape
>
result
;
std
::
unordered_map
<
std
::
string
,
shape
>
result
;
for
(
auto
&&
ins
:
impl
->
instructions
)
for
(
auto
&&
ins
:
impl
->
instructions
)
{
{
if
(
ins
.
op
.
name
()
==
"@param"
)
if
(
ins
.
name
()
==
"@param"
)
{
{
auto
&&
name
=
any_cast
<
builtin
::
param
>
(
ins
.
op
).
parameter
;
auto
&&
name
=
any_cast
<
builtin
::
param
>
(
ins
.
get_operator
()
).
parameter
;
result
[
name
]
=
ins
.
result
;
result
[
name
]
=
ins
.
get_shape
()
;
}
}
}
}
return
result
;
return
result
;
...
@@ -248,7 +248,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
...
@@ -248,7 +248,7 @@ std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref
program
::
begin
()
const
{
return
impl
->
instructions
.
begin
();
}
instruction_ref
program
::
begin
()
const
{
return
impl
->
instructions
.
begin
();
}
instruction_ref
program
::
end
()
const
{
return
impl
->
instructions
.
end
();
}
instruction_ref
program
::
end
()
const
{
return
impl
->
instructions
.
end
();
}
shape
program
::
get_shape
()
const
{
return
impl
->
instructions
.
back
().
result
;
}
shape
program
::
get_shape
()
const
{
return
impl
->
instructions
.
back
().
get_shape
()
;
}
instruction_ref
program
::
validate
()
const
instruction_ref
program
::
validate
()
const
{
{
...
@@ -277,7 +277,7 @@ void program::compile(const target& t, tracer trace)
...
@@ -277,7 +277,7 @@ void program::compile(const target& t, tracer trace)
{
{
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
auto
index
=
std
::
distance
(
impl
->
instructions
.
begin
(),
invalid
);
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
MIGRAPH_THROW
(
p
.
name
()
+
" pass produces invalid program at instruction "
+
std
::
to_string
(
index
)
+
": "
+
invalid
->
op
.
name
());
std
::
to_string
(
index
)
+
": "
+
invalid
->
name
());
}
}
trace
();
trace
();
#endif
#endif
...
@@ -303,32 +303,32 @@ argument generic_eval(const program& p,
...
@@ -303,32 +303,32 @@ argument generic_eval(const program& p,
values
.
reserve
(
16
);
values
.
reserve
(
16
);
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
->
op
.
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
ins
->
lit
.
get_argument
();
}));
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
ins
->
get_literal
()
.
get_argument
();
}));
}
}
else
if
(
ins
->
op
.
name
()
==
"@param"
)
else
if
(
ins
->
name
()
==
"@param"
)
{
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
params
.
at
(
any_cast
<
builtin
::
param
>
(
ins
->
op
).
parameter
);
return
params
.
at
(
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
);
}));
}));
}
}
else
if
(
ins
->
op
.
name
()
==
"@outline"
)
else
if
(
ins
->
name
()
==
"@outline"
)
{
{
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
argument
{
ins
->
result
,
nullptr
};
}));
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
return
argument
{
ins
->
get_shape
()
,
nullptr
};
}));
}
}
else
else
{
{
values
.
resize
(
ins
->
arguments
.
size
());
values
.
resize
(
ins
->
inputs
().
size
());
std
::
transform
(
ins
->
arguments
.
begin
(),
std
::
transform
(
ins
->
arguments
.
end
(),
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
values
.
begin
(),
[
&
](
instruction_ref
i
)
{
assert
(
results
.
find
(
i
)
!=
results
.
end
());
assert
(
results
.
find
(
i
)
!=
results
.
end
());
return
results
[
i
];
return
results
[
i
];
});
});
results
.
emplace
(
ins
,
results
.
emplace
(
ins
,
trace
(
ins
,
[
&
]
{
trace
(
ins
,
[
&
]
{
return
ins
->
op
.
compute
(
ctx
,
ins
->
result
,
values
);
}));
return
ins
->
get_operator
().
compute
(
ctx
,
ins
->
get_shape
(),
values
);
}));
}
}
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
assert
(
results
.
find
(
ins
)
!=
results
.
end
());
}
}
...
@@ -337,8 +337,20 @@ argument generic_eval(const program& p,
...
@@ -337,8 +337,20 @@ argument generic_eval(const program& p,
argument
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
argument
program
::
eval
(
std
::
unordered_map
<
std
::
string
,
argument
>
params
)
const
{
{
if
(
enabled
(
MIGRAPH_TRACE_EVAL
{}))
{
auto
&
ctx
=
this
->
impl
->
ctx
;
return
generic_eval
(
*
this
,
this
->
impl
->
ctx
,
std
::
move
(
params
),
[
&
](
auto
&
ins
,
auto
f
)
{
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins
->
name
()
<<
std
::
endl
;
return
f
();
});
}
else
{
return
generic_eval
(
return
generic_eval
(
*
this
,
this
->
impl
->
ctx
,
std
::
move
(
params
),
[](
auto
&
,
auto
f
)
{
return
f
();
});
*
this
,
this
->
impl
->
ctx
,
std
::
move
(
params
),
[](
auto
&
,
auto
f
)
{
return
f
();
});
}
}
}
double
common_average
(
const
std
::
vector
<
double
>&
v
)
double
common_average
(
const
std
::
vector
<
double
>&
v
)
...
@@ -404,7 +416,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
...
@@ -404,7 +416,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
for
(
auto
&&
p
:
ins_vec
)
for
(
auto
&&
p
:
ins_vec
)
{
{
double
avg
=
common_average
(
p
.
second
);
double
avg
=
common_average
(
p
.
second
);
op_times
[
p
.
first
->
op
.
name
()]
+=
avg
;
op_times
[
p
.
first
->
name
()]
+=
avg
;
total_instruction_time
+=
avg
;
total_instruction_time
+=
avg
;
}
}
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
double
calculate_overhead_time
=
total_time
-
total_instruction_time
;
...
...
src/simplify_reshapes.cpp
View file @
99ee76c0
...
@@ -25,26 +25,26 @@ void simplify_reshapes::apply(program& p) const
...
@@ -25,26 +25,26 @@ void simplify_reshapes::apply(program& p) const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
not
is_reshaper
(
ins
->
op
.
name
()))
if
(
not
is_reshaper
(
ins
->
name
()))
continue
;
continue
;
if
(
ins
->
output
.
size
()
!=
1
)
if
(
ins
->
output
s
()
.
size
()
!=
1
)
continue
;
continue
;
if
(
is_reshaper
(
ins
->
output
.
front
()
->
op
.
name
()))
if
(
is_reshaper
(
ins
->
output
s
()
.
front
()
->
name
()))
continue
;
continue
;
// Gather reshapes
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()
->
op
.
name
()))
while
(
is_reshaper
(
reshapes
.
back
()
->
name
()))
{
{
assert
(
!
reshapes
.
back
()
->
arguments
.
empty
());
assert
(
!
reshapes
.
back
()
->
inputs
()
.
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
arguments
.
front
()));
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
()
.
front
()));
reshapes
.
push_back
(
reshapes
.
back
()
->
arguments
.
front
());
reshapes
.
push_back
(
reshapes
.
back
()
->
inputs
()
.
front
());
}
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
for
(
auto
start
:
iterator_for
(
reshapes
))
{
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
result
==
(
*
start
)
->
result
and
i
!=
(
*
start
);
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
});
if
(
last
!=
reshapes
.
rend
())
if
(
last
!=
reshapes
.
rend
())
{
{
...
...
src/targets/cpu/cpu_lowering.cpp
View file @
99ee76c0
...
@@ -134,6 +134,63 @@ struct cpu_convolution
...
@@ -134,6 +134,63 @@ struct cpu_convolution
}
}
};
};
struct
cpu_im2col
{
im2col
op
;
static
std
::
string
name
()
{
return
"cpu::im2col"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
input_shape
=
args
[
0
].
get_shape
();
auto
weights_shape
=
args
[
1
].
get_shape
();
visit_all
(
result
,
args
[
0
])([
&
](
auto
col
,
auto
input
)
{
const
std
::
size_t
&
height
=
input_shape
.
lens
()[
2
];
const
std
::
size_t
&
width
=
input_shape
.
lens
()[
3
];
const
std
::
size_t
&
channels
=
weights_shape
.
lens
()[
1
];
const
std
::
size_t
&
kernel_h
=
weights_shape
.
lens
()[
2
];
const
std
::
size_t
&
kernel_w
=
weights_shape
.
lens
()[
3
];
const
std
::
size_t
&
pad_h
=
op
.
padding
[
0
];
const
std
::
size_t
&
pad_w
=
op
.
padding
[
1
];
const
std
::
size_t
&
stride_h
=
op
.
stride
[
0
];
const
std
::
size_t
&
stride_w
=
op
.
stride
[
1
];
int
kdiv2_h
,
kdiv2_w
;
kdiv2_h
=
kernel_h
/
2
;
kdiv2_w
=
kernel_w
/
2
;
// calculate output sizes
const
std
::
size_t
col_height
=
(
height
-
kernel_h
+
2
*
pad_h
)
/
stride_h
+
1
;
const
std
::
size_t
col_width
=
(
width
-
kernel_w
+
2
*
pad_w
)
/
stride_w
+
1
;
// account for padding for the starting position of the input pixels
std
::
size_t
iinput
=
kdiv2_h
-
pad_h
;
// loop over output pixels (ioutput, joutput)
for
(
std
::
size_t
ioutput
=
0
;
ioutput
<
col_height
;
ioutput
++
,
iinput
+=
stride_h
)
{
std
::
size_t
jinput
=
kdiv2_w
-
pad_w
;
for
(
std
::
size_t
joutput
=
0
;
joutput
<
col_width
;
joutput
++
,
jinput
+=
stride_w
)
{
// compute linear index for output
std
::
size_t
ldx
=
ioutput
*
col_width
+
joutput
;
std
::
size_t
p
=
0
;
dfor
(
channels
,
kernel_h
,
kernel_w
)([
&
](
std
::
size_t
c
,
std
::
size_t
koffset
,
std
::
size_t
loffset
)
{
int
idx
=
iinput
+
koffset
-
kdiv2_h
;
int
jdx
=
jinput
+
loffset
-
kdiv2_w
;
col
(
ldx
,
p
)
=
((
idx
>=
0
)
&&
(
idx
<
height
)
&&
(
jdx
>=
0
)
&&
(
jdx
<
width
))
?
input
(
0
,
c
,
idx
,
jdx
)
:
0
;
p
++
;
});
}
}
});
return
result
;
}
};
struct
max_pool
struct
max_pool
{
{
static
std
::
string
name
()
{
return
"max"
;
}
static
std
::
string
name
()
{
return
"max"
;
}
...
@@ -494,6 +551,7 @@ struct cpu_apply
...
@@ -494,6 +551,7 @@ struct cpu_apply
void
init
()
void
init
()
{
{
apply_map
[
"im2col"
]
=
extend_op
<
cpu_im2col
,
im2col
>
();
apply_map
[
"convolution"
]
=
extend_op
<
cpu_convolution
,
convolution
>
();
apply_map
[
"convolution"
]
=
extend_op
<
cpu_convolution
,
convolution
>
();
apply_map
[
"gemm"
]
=
extend_op
<
cpu_gemm
,
gemm
>
();
apply_map
[
"gemm"
]
=
extend_op
<
cpu_gemm
,
gemm
>
();
apply_map
[
"batch_norm_inference"
]
=
apply_map
[
"batch_norm_inference"
]
=
...
@@ -521,17 +579,17 @@ struct cpu_apply
...
@@ -521,17 +579,17 @@ struct cpu_apply
init
();
init
();
for
(
auto
it
:
iterator_for
(
*
prog
))
for
(
auto
it
:
iterator_for
(
*
prog
))
{
{
if
(
it
->
op
.
name
()
==
"activation"
)
if
(
it
->
name
()
==
"activation"
)
{
{
apply_activation
(
it
);
apply_activation
(
it
);
}
}
else
if
(
it
->
op
.
name
()
==
"pooling"
)
else
if
(
it
->
name
()
==
"pooling"
)
{
{
apply_pooling
(
it
);
apply_pooling
(
it
);
}
}
else
if
(
apply_map
.
count
(
it
->
op
.
name
())
>
0
)
else
if
(
apply_map
.
count
(
it
->
name
())
>
0
)
{
{
apply_map
.
at
(
it
->
op
.
name
())(
it
);
apply_map
.
at
(
it
->
name
())(
it
);
}
}
}
}
}
}
...
@@ -539,30 +597,30 @@ struct cpu_apply
...
@@ -539,30 +597,30 @@ struct cpu_apply
template
<
class
T
>
template
<
class
T
>
void
apply_simple_op
(
instruction_ref
ins
)
void
apply_simple_op
(
instruction_ref
ins
)
{
{
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{},
ins
->
inputs
()
);
}
}
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
apply_extend_op
(
instruction_ref
ins
)
void
apply_extend_op
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
get_operator
()
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
T
{
op
},
ins
->
inputs
()
);
}
}
void
apply_activation
(
instruction_ref
ins
)
void
apply_activation
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
get_operator
()
);
if
(
op
.
mode
==
"relu"
)
if
(
op
.
mode
==
"relu"
)
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_unary
<
relu_op
>
{},
ins
->
inputs
()
);
}
}
void
apply_pooling
(
instruction_ref
ins
)
void
apply_pooling
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
get_operator
()
);
if
(
op
.
mode
==
"max"
)
if
(
op
.
mode
==
"max"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
max_pool
>
{
op
},
ins
->
inputs
()
);
else
if
(
op
.
mode
==
"average"
)
else
if
(
op
.
mode
==
"average"
)
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
arguments
);
prog
->
replace_instruction
(
ins
,
cpu_pooling
<
avg_pool
>
{
op
},
ins
->
inputs
()
);
}
}
};
};
...
...
src/targets/gpu/eliminate_workspace.cpp
View file @
99ee76c0
...
@@ -20,11 +20,11 @@ void eliminate_workspace::apply(program& p) const
...
@@ -20,11 +20,11 @@ void eliminate_workspace::apply(program& p) const
std
::
vector
<
instruction_ref
>
allocs
;
std
::
vector
<
instruction_ref
>
allocs
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
->
output
.
size
()
!=
1
)
if
(
ins
->
output
s
()
.
size
()
!=
1
)
continue
;
continue
;
if
(
ins
->
op
.
name
()
!=
"hip::allocate"
)
if
(
ins
->
name
()
!=
"hip::allocate"
)
continue
;
continue
;
auto
&&
a
=
any_cast
<
hip_allocate
>
(
ins
->
op
);
auto
&&
a
=
any_cast
<
hip_allocate
>
(
ins
->
get_operator
()
);
if
(
a
.
tag
==
"workspace"
)
if
(
a
.
tag
==
"workspace"
)
{
{
n
=
std
::
max
(
n
,
ins
->
get_shape
().
bytes
());
n
=
std
::
max
(
n
,
ins
->
get_shape
().
bytes
());
...
...
src/targets/gpu/fuse_ops.cpp
View file @
99ee76c0
...
@@ -26,14 +26,14 @@ void fuse_ops::apply(program& p) const
...
@@ -26,14 +26,14 @@ void fuse_ops::apply(program& p) const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
->
op
.
name
()
!=
"gpu::relu"
)
if
(
ins
->
name
()
!=
"gpu::relu"
)
continue
;
continue
;
auto
add_ins
=
ins
->
arguments
.
front
();
auto
add_ins
=
ins
->
inputs
()
.
front
();
if
(
add_ins
->
op
.
name
()
!=
"gpu::add"
)
if
(
add_ins
->
name
()
!=
"gpu::add"
)
continue
;
continue
;
auto
args
=
add_ins
->
arguments
;
auto
args
=
add_ins
->
inputs
()
;
// Use the allocation from the relu operator
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
arguments
.
back
();
args
.
back
()
=
ins
->
inputs
()
.
back
();
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
p
.
replace_instruction
(
ins
,
hip_add_relu
{},
args
);
}
}
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
99ee76c0
...
@@ -132,8 +132,16 @@ struct miopen_convolution
...
@@ -132,8 +132,16 @@ struct miopen_convolution
workspace_size
,
workspace_size
,
false
);
false
);
algo
=
perf
.
fwd_algo
;
algo
=
perf
.
fwd_algo
;
return
algo
==
miopenConvolutionFwdAlgoWinograd
?
shape
{
shape
::
int8_type
,
{
0
}}
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
:
workspace_shape
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
miopen_convolution
&
self
)
{
os
<<
self
.
name
()
<<
"["
;
os
<<
self
.
op
<<
", "
;
os
<<
"algo="
<<
self
.
algo
;
os
<<
"]"
;
return
os
;
}
}
};
};
...
@@ -308,6 +316,34 @@ struct miopen_relu
...
@@ -308,6 +316,34 @@ struct miopen_relu
}
}
};
};
struct
miopen_softmax
{
softmax
op
;
std
::
string
name
()
const
{
return
"gpu::softmax"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
float
alpha
=
1
,
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenSoftmaxForward
(
ctx
.
handle
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
1
].
implicit
());
return
args
[
1
];
}
};
struct
miopen_apply
struct
miopen_apply
{
{
program
*
prog
=
nullptr
;
program
*
prog
=
nullptr
;
...
@@ -325,34 +361,38 @@ struct miopen_apply
...
@@ -325,34 +361,38 @@ struct miopen_apply
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
for
(
auto
it
=
prog
->
begin
();
it
!=
prog
->
end
();
it
++
)
{
{
auto
s
=
it
->
get_shape
();
auto
s
=
it
->
get_shape
();
if
(
it
->
op
.
name
()
==
"convolution"
)
if
(
it
->
name
()
==
"convolution"
)
{
{
check_shape
(
s
,
apply_convolution
(
it
));
check_shape
(
s
,
apply_convolution
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"activation"
)
else
if
(
it
->
name
()
==
"activation"
)
{
{
check_shape
(
s
,
apply_activation
(
it
));
check_shape
(
s
,
apply_activation
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"pooling"
)
else
if
(
it
->
name
()
==
"pooling"
)
{
{
check_shape
(
s
,
apply_pooling
(
it
));
check_shape
(
s
,
apply_pooling
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"add"
)
else
if
(
it
->
name
()
==
"add"
)
{
{
check_shape
(
s
,
apply_add
(
it
));
check_shape
(
s
,
apply_add
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"gemm"
)
else
if
(
it
->
name
()
==
"gemm"
)
{
{
check_shape
(
s
,
apply_gemm
(
it
));
check_shape
(
s
,
apply_gemm
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"contiguous"
)
else
if
(
it
->
name
()
==
"contiguous"
)
{
{
check_shape
(
s
,
apply_contiguous
(
it
));
check_shape
(
s
,
apply_contiguous
(
it
));
}
}
else
if
(
it
->
op
.
name
()
==
"batch_norm_inference"
)
else
if
(
it
->
name
()
==
"batch_norm_inference"
)
{
{
check_shape
(
s
,
apply_batch_norm_inference
(
it
));
check_shape
(
s
,
apply_batch_norm_inference
(
it
));
}
}
else
if
(
it
->
name
()
==
"softmax"
)
{
check_shape
(
s
,
apply_softmax
(
it
));
}
}
}
}
}
...
@@ -372,78 +412,85 @@ struct miopen_apply
...
@@ -372,78 +412,85 @@ struct miopen_apply
instruction_ref
apply_convolution
(
instruction_ref
ins
)
instruction_ref
apply_convolution
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
convolution
>
(
ins
->
get_operator
()
);
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
conv
=
miopen_convolution
{
op
,
make_conv
(
op
)};
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
result
,
ins
->
arguments
);
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
ins
->
inputs
()
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
conv
,
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
workspace
,
output
);
ins
,
conv
,
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
workspace
,
output
);
}
}
instruction_ref
apply_pooling
(
instruction_ref
ins
)
instruction_ref
apply_pooling
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
pooling
>
(
ins
->
get_operator
()
);
auto
pd
=
make_pooling
(
op
);
auto
pd
=
make_pooling
(
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_pooling
{
op
,
std
::
move
(
pd
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_activation
(
instruction_ref
ins
)
instruction_ref
apply_activation
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
activation
>
(
ins
->
get_operator
()
);
auto
ad
=
make_relu
();
auto
ad
=
make_relu
();
if
(
op
.
mode
==
"relu"
)
if
(
op
.
mode
==
"relu"
)
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
arguments
.
at
(
0
),
output
);
ins
,
miopen_relu
{
std
::
move
(
ad
)},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
return
ins
;
return
ins
;
}
}
instruction_ref
apply_softmax
(
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
softmax
>
(
ins
->
get_operator
());
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
ins
,
miopen_softmax
{
op
},
ins
->
inputs
().
at
(
0
),
output
);
}
instruction_ref
apply_add
(
instruction_ref
ins
)
instruction_ref
apply_add
(
instruction_ref
ins
)
{
{
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
hip_add
{},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
hip_add
{},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_gemm
(
instruction_ref
ins
)
instruction_ref
apply_gemm
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
gemm
>
(
ins
->
get_operator
()
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
return
prog
->
replace_instruction
(
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
ins
,
miopen_gemm
{
op
},
ins
->
inputs
()
.
at
(
0
),
ins
->
inputs
()
.
at
(
1
),
output
);
}
}
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
instruction_ref
apply_contiguous
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
get_operator
()
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
return
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
inputs
()
.
at
(
0
),
output
);
}
}
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
instruction_ref
apply_batch_norm_inference
(
instruction_ref
ins
)
{
{
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
get_operator
()
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
()
);
shape
old_shape
=
ins
->
arguments
.
at
(
1
)
->
get_shape
();
shape
old_shape
=
ins
->
inputs
()
.
at
(
1
)
->
get_shape
();
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
std
::
vector
<
int64_t
>
new_shape
{
1
,
static_cast
<
int64_t
>
(
old_shape
.
elements
()),
1
,
1
};
auto
reshape_op
=
reshape
{
new_shape
};
auto
reshape_op
=
reshape
{
new_shape
};
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
vector
<
instruction_ref
>
reshapes
;
std
::
transform
(
ins
->
arguments
.
begin
()
+
1
,
std
::
transform
(
ins
->
inputs
()
.
begin
()
+
1
,
ins
->
arguments
.
end
(),
ins
->
inputs
()
.
end
(),
std
::
back_inserter
(
reshapes
),
std
::
back_inserter
(
reshapes
),
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
[
&
](
auto
i
)
{
return
prog
->
insert_instruction
(
ins
,
reshape_op
,
i
);
});
return
prog
->
replace_instruction
(
ins
,
return
prog
->
replace_instruction
(
ins
,
miopen_batch_norm_inference
{
op
},
miopen_batch_norm_inference
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
inputs
()
.
at
(
0
),
reshapes
[
0
],
reshapes
[
0
],
reshapes
[
1
],
reshapes
[
1
],
reshapes
[
2
],
reshapes
[
2
],
...
...
src/targets/gpu/target.cpp
View file @
99ee76c0
...
@@ -32,11 +32,11 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
...
@@ -32,11 +32,11 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
memory_coloring
{
"hip::allocate"
},
memory_coloring
{
"hip::allocate"
},
fuse_ops
{},
fuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_workspace
{},
eliminate_contiguous
{},
eliminate_contiguous
{},
dead_code_elimination
{},
dead_code_elimination
{},
write_literals
{
&
ctx
},
write_literals
{
&
ctx
},
eliminate_allocation
{
""
},
eliminate_workspace
{},
eliminate_allocation
{
"hip::allocate"
},
check_context
<
context
>
{},
check_context
<
context
>
{},
dead_code_elimination
{}
dead_code_elimination
{}
};
};
...
...
src/targets/gpu/write_literals.cpp
View file @
99ee76c0
...
@@ -28,9 +28,9 @@ void write_literals::apply(program& p) const
...
@@ -28,9 +28,9 @@ void write_literals::apply(program& p) const
assert
(
ctx
!=
nullptr
);
assert
(
ctx
!=
nullptr
);
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
if
(
ins
->
op
.
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
argument
a
=
to_gpu
(
ins
->
lit
.
get_argument
());
argument
a
=
to_gpu
(
ins
->
get_literal
()
.
get_argument
());
std
::
size_t
n
=
ctx
->
literals
.
size
();
std
::
size_t
n
=
ctx
->
literals
.
size
();
ctx
->
literals
.
push_back
(
a
);
ctx
->
literals
.
push_back
(
a
);
p
.
replace_instruction
(
ins
,
hip_load_literal
{
a
.
get_shape
(),
n
});
p
.
replace_instruction
(
ins
,
hip_load_literal
{
a
.
get_shape
(),
n
});
...
...
test/cpu_ops_test.cpp
View file @
99ee76c0
...
@@ -6,6 +6,132 @@
...
@@ -6,6 +6,132 @@
#include <migraph/verify.hpp>
#include <migraph/verify.hpp>
#include "test.hpp"
#include "test.hpp"
void
im2col_3x3_no_pad_identity_test
()
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
3
,
3
};
std
::
array
<
std
::
size_t
,
2
>
padding
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
{{
1
,
1
}};
std
::
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
std
::
vector
<
int32_t
>
input
(
channels
*
size
[
0
]
*
size
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraph
::
program
p
;
migraph
::
shape
s_image
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
size
[
0
],
size
[
1
]}};
migraph
::
shape
s_weights
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_image
=
p
.
add_literal
(
migraph
::
literal
{
s_image
,
input
});
auto
l_weights
=
p
.
add_literal
(
migraph
::
literal
{
s_weights
,
weights
});
p
.
add_instruction
(
migraph
::
im2col
{
padding
,
stride
,
dilation
},
l_image
,
l_weights
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
size_t
col_height
=
(
size
[
0
]
-
f
[
0
]
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
;
std
::
size_t
col_width
=
(
size
[
1
]
-
f
[
1
]
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
;
std
::
vector
<
float
>
results_vector
(
channels
*
f
[
0
]
*
f
[
1
]
*
col_height
*
col_width
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
input
));
}
void
im2col_3x3_no_pad_test
()
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
4
,
4
};
std
::
array
<
std
::
size_t
,
2
>
padding
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
{{
1
,
1
}};
std
::
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
std
::
vector
<
int32_t
>
input
(
channels
*
size
[
0
]
*
size
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraph
::
program
p
;
migraph
::
shape
s_image
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
size
[
0
],
size
[
1
]}};
migraph
::
shape
s_weights
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_image
=
p
.
add_literal
(
migraph
::
literal
{
s_image
,
input
});
auto
l_weights
=
p
.
add_literal
(
migraph
::
literal
{
s_weights
,
weights
});
p
.
add_instruction
(
migraph
::
im2col
{
padding
,
stride
,
dilation
},
l_image
,
l_weights
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
correct
=
{
0
,
1
,
2
,
4
,
5
,
6
,
8
,
9
,
10
,
1
,
2
,
3
,
5
,
6
,
7
,
9
,
10
,
11
,
4
,
5
,
6
,
8
,
9
,
10
,
12
,
13
,
14
,
5
,
6
,
7
,
9
,
10
,
11
,
13
,
14
,
15
};
std
::
size_t
col_height
=
(
size
[
0
]
-
f
[
0
]
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
;
std
::
size_t
col_width
=
(
size
[
1
]
-
f
[
1
]
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
;
std
::
vector
<
float
>
results_vector
(
channels
*
f
[
0
]
*
f
[
1
]
*
col_height
*
col_width
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
im2col_3x3_stride_2_no_pad_test
()
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
6
,
6
};
std
::
array
<
std
::
size_t
,
2
>
padding
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
{{
2
,
2
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
{{
1
,
1
}};
std
::
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
std
::
vector
<
int32_t
>
input
(
channels
*
size
[
0
]
*
size
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraph
::
program
p
;
migraph
::
shape
s_image
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
size
[
0
],
size
[
1
]}};
migraph
::
shape
s_weights
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_image
=
p
.
add_literal
(
migraph
::
literal
{
s_image
,
input
});
auto
l_weights
=
p
.
add_literal
(
migraph
::
literal
{
s_weights
,
weights
});
p
.
add_instruction
(
migraph
::
im2col
{
padding
,
stride
,
dilation
},
l_image
,
l_weights
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
correct
=
{
0
,
1
,
2
,
6
,
7
,
8
,
12
,
13
,
14
,
2
,
3
,
4
,
8
,
9
,
10
,
14
,
15
,
16
,
12
,
13
,
14
,
18
,
19
,
20
,
24
,
25
,
26
,
14
,
15
,
16
,
20
,
21
,
22
,
26
,
27
,
28
};
std
::
size_t
col_height
=
(
size
[
0
]
-
f
[
0
]
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
;
std
::
size_t
col_width
=
(
size
[
1
]
-
f
[
1
]
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
;
std
::
vector
<
float
>
results_vector
(
channels
*
f
[
0
]
*
f
[
1
]
*
col_height
*
col_width
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
im2col_3x3_with_padding_test
()
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
2
,
2
};
std
::
array
<
std
::
size_t
,
2
>
padding
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
stride
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
{{
1
,
1
}};
std
::
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
std
::
vector
<
int32_t
>
input
(
channels
*
size
[
0
]
*
size
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraph
::
program
p
;
migraph
::
shape
s_image
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
size
[
0
],
size
[
1
]}};
migraph
::
shape
s_weights
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_image
=
p
.
add_literal
(
migraph
::
literal
{
s_image
,
input
});
auto
l_weights
=
p
.
add_literal
(
migraph
::
literal
{
s_weights
,
weights
});
p
.
add_instruction
(
migraph
::
im2col
{
padding
,
stride
,
dilation
},
l_image
,
l_weights
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
correct
=
{
0
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
3
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
3
,
0
,
0
,
0
,
1
,
0
,
2
,
3
,
0
,
0
,
0
,
0
,
1
,
0
,
2
,
3
,
0
,
0
,
0
,
0
};
std
::
size_t
col_height
=
(
size
[
0
]
-
f
[
0
]
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
;
std
::
size_t
col_width
=
(
size
[
1
]
-
f
[
1
]
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
;
std
::
vector
<
float
>
results_vector
(
channels
*
f
[
0
]
*
f
[
1
]
*
col_height
*
col_width
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
batch_norm_inference_test
()
void
batch_norm_inference_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -46,6 +172,35 @@ void batch_norm_inference_test()
...
@@ -46,6 +172,35 @@ void batch_norm_inference_test()
EXPECT
(
migraph
::
verify_range
(
result_vector
,
gold
));
EXPECT
(
migraph
::
verify_range
(
result_vector
,
gold
));
}
}
void
im2col_3x3_with_channels_identity_test
()
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
3
,
3
};
std
::
array
<
std
::
size_t
,
2
>
padding
{{
0
,
0
}};
std
::
array
<
std
::
size_t
,
2
>
stride
{{
1
,
1
}};
std
::
array
<
std
::
size_t
,
2
>
dilation
{{
1
,
1
}};
std
::
size_t
channels
=
2
;
std
::
vector
<
int32_t
>
weights
(
channels
*
f
[
0
]
*
f
[
1
]);
std
::
vector
<
int32_t
>
input
(
channels
*
size
[
0
]
*
size
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraph
::
program
p
;
migraph
::
shape
s_image
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
size
[
0
],
size
[
1
]}};
migraph
::
shape
s_weights
{
migraph
::
shape
::
int32_type
,
{
1
,
channels
,
f
[
0
],
f
[
1
]}};
auto
l_image
=
p
.
add_literal
(
migraph
::
literal
{
s_image
,
input
});
auto
l_weights
=
p
.
add_literal
(
migraph
::
literal
{
s_weights
,
weights
});
p
.
add_instruction
(
migraph
::
im2col
{
padding
,
stride
,
dilation
},
l_image
,
l_weights
);
p
.
compile
(
migraph
::
cpu
::
cpu_target
{});
auto
result
=
p
.
eval
({});
std
::
size_t
col_height
=
(
size
[
0
]
-
f
[
0
]
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
;
std
::
size_t
col_width
=
(
size
[
1
]
-
f
[
1
]
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
;
std
::
vector
<
float
>
results_vector
(
channels
*
f
[
0
]
*
f
[
1
]
*
col_height
*
col_width
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraph
::
verify_range
(
results_vector
,
input
));
}
void
exp_test
()
void
exp_test
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -666,4 +821,9 @@ int main()
...
@@ -666,4 +821,9 @@ int main()
conv2d_padding_test
();
conv2d_padding_test
();
conv2d_padding_stride_test
();
conv2d_padding_stride_test
();
batch_norm_inference_test
();
batch_norm_inference_test
();
im2col_3x3_no_pad_identity_test
();
im2col_3x3_no_pad_test
();
im2col_3x3_stride_2_no_pad_test
();
im2col_3x3_with_channels_identity_test
();
im2col_3x3_with_padding_test
();
}
}
test/eval_test.cpp
View file @
99ee76c0
...
@@ -21,13 +21,13 @@ struct reverse_pass
...
@@ -21,13 +21,13 @@ struct reverse_pass
{
{
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
for
(
auto
ins
:
migraph
::
iterator_for
(
p
))
{
{
if
(
ins
->
op
.
name
()
==
"sum"
)
if
(
ins
->
name
()
==
"sum"
)
{
{
p
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
arguments
);
p
.
replace_instruction
(
ins
,
minus_op
{},
ins
->
inputs
()
);
}
}
else
if
(
ins
->
op
.
name
()
==
"minus"
)
else
if
(
ins
->
name
()
==
"minus"
)
{
{
p
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
arguments
);
p
.
replace_instruction
(
ins
,
sum_op
{},
ins
->
inputs
()
);
}
}
}
}
}
}
...
...
test/gpu/miopen.cpp
View file @
99ee76c0
...
@@ -97,10 +97,10 @@ void compile_check(migraph::program& p, const migraph::target& t)
...
@@ -97,10 +97,10 @@ void compile_check(migraph::program& p, const migraph::target& t)
}
}
template
<
class
V
>
template
<
class
V
>
migraph
::
argument
run_cpu
()
migraph
::
argument
run_cpu
(
migraph
::
program
&
p
)
{
{
V
v
;
V
v
;
auto
p
=
v
.
create_program
();
p
=
v
.
create_program
();
auto_print
pp
{
p
,
0
};
auto_print
pp
{
p
,
0
};
compile_check
(
p
,
migraph
::
cpu
::
cpu_target
{});
compile_check
(
p
,
migraph
::
cpu
::
cpu_target
{});
migraph
::
program
::
parameter_map
m
;
migraph
::
program
::
parameter_map
m
;
...
@@ -112,10 +112,10 @@ migraph::argument run_cpu()
...
@@ -112,10 +112,10 @@ migraph::argument run_cpu()
}
}
template
<
class
V
>
template
<
class
V
>
migraph
::
argument
run_gpu
()
migraph
::
argument
run_gpu
(
migraph
::
program
&
p
)
{
{
V
v
;
V
v
;
auto
p
=
v
.
create_program
();
p
=
v
.
create_program
();
auto_print
pp
{
p
,
1
};
auto_print
pp
{
p
,
1
};
compile_check
(
p
,
migraph
::
gpu
::
target
{});
compile_check
(
p
,
migraph
::
gpu
::
target
{});
migraph
::
program
::
parameter_map
m
;
migraph
::
program
::
parameter_map
m
;
...
@@ -131,9 +131,20 @@ template <class V>
...
@@ -131,9 +131,20 @@ template <class V>
void
verify_program
()
void
verify_program
()
{
{
auto_print
::
set_terminate_handler
(
migraph
::
get_type_name
<
V
>
());
auto_print
::
set_terminate_handler
(
migraph
::
get_type_name
<
V
>
());
auto
cpu_arg_f
=
detach_async
([]
{
return
run_cpu
<
V
>
();
});
migraph
::
program
cpu_prog
;
auto
gpu_arg
=
run_gpu
<
V
>
();
migraph
::
program
gpu_prog
;
verify_args
(
migraph
::
get_type_name
<
V
>
(),
cpu_arg_f
.
get
(),
gpu_arg
);
auto
cpu_arg_f
=
detach_async
([
&
]
{
return
run_cpu
<
V
>
(
cpu_prog
);
});
auto
gpu_arg
=
run_gpu
<
V
>
(
gpu_prog
);
bool
passed
=
verify_args
(
migraph
::
get_type_name
<
V
>
(),
cpu_arg_f
.
get
(),
gpu_arg
);
if
(
not
passed
)
{
V
v
;
auto
p
=
v
.
create_program
();
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
"cpu:
\n
"
<<
cpu_prog
<<
std
::
endl
;
std
::
cout
<<
"gpu:
\n
"
<<
gpu_prog
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
}
std
::
set_terminate
(
nullptr
);
std
::
set_terminate
(
nullptr
);
}
}
...
@@ -235,6 +246,28 @@ struct test_add_broadcast5
...
@@ -235,6 +246,28 @@ struct test_add_broadcast5
}
}
};
};
struct
test_softmax
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
,
4
,
2
}});
p
.
add_instruction
(
migraph
::
softmax
{},
x
);
return
p
;
}
};
struct
test_softmax2
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
1000
,
1
,
1
}});
p
.
add_instruction
(
migraph
::
softmax
{},
x
);
return
p
;
}
};
struct
test_conv
struct
test_conv
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -248,6 +281,20 @@ struct test_conv
...
@@ -248,6 +281,20 @@ struct test_conv
}
}
};
};
// Verification case: one convolution of a {1, 512, 28, 28} float input "x"
// with {256, 512, 1, 1} float weights "w".
struct test_conv2
{
    // Builds and returns the program under test.
    migraph::program create_program() const
    {
        migraph::program prog;

        migraph::shape input_shape{migraph::shape::float_type, {1, 512, 28, 28}};
        migraph::shape weight_shape{migraph::shape::float_type, {256, 512, 1, 1}};

        auto data    = prog.add_parameter("x", input_shape);
        auto filters = prog.add_parameter("w", weight_shape);

        // NOTE(review): the three pairs are presumably padding, stride and
        // dilation, in that order -- confirm against migraph::convolution.
        prog.add_instruction(migraph::convolution{{0, 0}, {1, 1}, {1, 1}}, data, filters);
        return prog;
    }
};
struct
test_conv_relu
struct
test_conv_relu
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -428,6 +475,27 @@ struct test_batchnorm_inference
...
@@ -428,6 +475,27 @@ struct test_batchnorm_inference
}
}
};
};
// Verification case: a convolution whose result feeds batch_norm_inference.
// The input "x" is {1, 3, 224, 224}, the weights "w" are {64, 3, 7, 7}, and
// the four batch-norm operands are generated literals of shape {64}.
struct test_conv_bn
{
    // Builds and returns the program under test.
    migraph::program create_program() const
    {
        migraph::program prog;

        migraph::shape input_shape{migraph::shape::float_type, {1, 3, 224, 224}};
        migraph::shape weight_shape{migraph::shape::float_type, {64, 3, 7, 7}};
        migraph::shape channel_shape{migraph::shape::float_type, {64}};

        auto data    = prog.add_parameter("x", input_shape);
        auto filters = prog.add_parameter("w", weight_shape);

        // NOTE(review): the three pairs are presumably padding, stride and
        // dilation, in that order -- confirm against migraph::convolution.
        auto conv_out =
            prog.add_instruction(migraph::convolution{{3, 3}, {2, 2}, {1, 1}}, data, filters);

        // Adds one per-channel literal: generate_literal(channel_shape, seed)
        // passed through migraph::abs.
        // NOTE(review): the second argument looks like a generator seed and
        // abs presumably keeps the operands non-negative -- confirm.
        auto add_abs_literal = [&](auto seed) {
            return prog.add_literal(
                migraph::abs(migraph::generate_literal(channel_shape, seed)));
        };

        // Literals are added in the same order as the operands are consumed.
        auto bn_scale    = add_abs_literal(1);
        auto bn_bias     = add_abs_literal(2);
        auto bn_mean     = add_abs_literal(3);
        auto bn_variance = add_abs_literal(4);

        prog.add_instruction(
            migraph::batch_norm_inference{}, conv_out, bn_scale, bn_bias, bn_mean, bn_variance);
        return prog;
    }
};
struct
test_conv_bn_relu_pooling
struct
test_conv_bn_relu_pooling
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -495,7 +563,10 @@ int main()
...
@@ -495,7 +563,10 @@ int main()
verify_program
<
test_add_broadcast3
>
();
verify_program
<
test_add_broadcast3
>
();
verify_program
<
test_add_broadcast4
>
();
verify_program
<
test_add_broadcast4
>
();
verify_program
<
test_add_broadcast5
>
();
verify_program
<
test_add_broadcast5
>
();
verify_program
<
test_softmax
>
();
verify_program
<
test_softmax2
>
();
verify_program
<
test_conv
>
();
verify_program
<
test_conv
>
();
verify_program
<
test_conv2
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_conv_pooling
>
();
...
@@ -508,6 +579,7 @@ int main()
...
@@ -508,6 +579,7 @@ int main()
verify_program
<
test_transpose
>
();
verify_program
<
test_transpose
>
();
verify_program
<
test_batchnorm_inference
>
();
verify_program
<
test_batchnorm_inference
>
();
verify_program
<
test_batchnorm_inference_2
>
();
verify_program
<
test_batchnorm_inference_2
>
();
verify_program
<
test_conv_bn
>
();
verify_program
<
test_conv_bn_relu_pooling
>
();
verify_program
<
test_conv_bn_relu_pooling
>
();
verify_program
<
test_conv_bn_relu_pooling2
>
();
verify_program
<
test_conv_bn_relu_pooling2
>
();
}
}
test/include/rob.hpp
0 → 100644
View file @
99ee76c0
#ifndef MIGRAPH_GUARD_ROB_HPP
#define MIGRAPH_GUARD_ROB_HPP
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
// Used to access private member variables
// Per-tag storage slot. `Tag::type` names the type of the stored value; the
// tags produced by MIGRAPH_ROB derive from mem_data_ptr, so for them it is a
// pointer-to-data-member.
template <typename Tag>
struct stowed
{
    static typename Tag::type value;
};

// Out-of-class definition of the slot (one object per distinct Tag).
template <typename Tag>
typename Tag::type stowed<Tag>::value;
// Stores the compile-time value X into stowed<Tag>::value whenever an object
// of this type is constructed. The class carries a static object of its own
// type, `instance`, whose construction performs that store.
template <typename Tag, typename Tag::type X>
struct stow_private
{
    static stow_private instance;

    stow_private() noexcept { stowed<Tag>::value = X; }
};

// Definition of the static object declared above.
template <typename Tag, typename Tag::type X>
stow_private<Tag, X> stow_private<Tag, X>::instance;
// Tag base class: exposes `type` as "pointer to a data member of class C
// whose type is T".
template <typename C, typename T>
struct mem_data_ptr
{
    typedef T C::*type;
};
// MIGRAPH_ROB(name, Type, C, mem) defines an accessor `name(x)` that returns a
// reference to the data member `mem` (of type `Type`) of an object `x` of
// class `C`. It expands to three things:
//   1. `name##_tag`, a tag deriving from mem_data_ptr<C, Type>;
//   2. an explicit instantiation of stow_private<name##_tag, &C::mem>, which
//      is what records `&C::mem` in stowed<name##_tag>::value (explicit
//      instantiations are exempt from the usual access checks, which is how
//      `&C::mem` can name a private member here);
//   3. the function template `name`, which applies the stored member pointer
//      to its argument.
// Must be invoked at namespace scope, since it contains an explicit
// instantiation.
#define MIGRAPH_ROB(name, Type, C, mem) \
struct name##_tag : mem_data_ptr<C, Type> \
{ \
}; \
template struct stow_private<name##_tag, &C::mem>; \
template <class T> \
auto& name(T&& x) \
{ \
return x.*stowed<name##_tag>::value; \
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif
test/operation.cpp
View file @
99ee76c0
...
@@ -43,7 +43,9 @@ void operation_copy_test()
...
@@ -43,7 +43,9 @@ void operation_copy_test()
simple_operation
s
{};
simple_operation
s
{};
migraph
::
operation
op1
=
s
;
// NOLINT
migraph
::
operation
op1
=
s
;
// NOLINT
migraph
::
operation
op2
=
op1
;
// NOLINT
migraph
::
operation
op2
=
op1
;
// NOLINT
// cppcheck-suppress duplicateExpression
EXPECT
(
s
.
name
()
==
op1
.
name
());
EXPECT
(
s
.
name
()
==
op1
.
name
());
// cppcheck-suppress duplicateExpression
EXPECT
(
op2
.
name
()
==
op1
.
name
());
EXPECT
(
op2
.
name
()
==
op1
.
name
());
}
}
...
...
test/validate.cpp
View file @
99ee76c0
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraph/instruction.hpp>
#include <migraph/instruction.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <test.hpp>
#include <rob.hpp>
void
simple_test
()
void
simple_test
()
{
{
...
@@ -38,6 +39,11 @@ void incomplete_args()
...
@@ -38,6 +39,11 @@ void incomplete_args()
EXPECT
(
bool
{
p
.
validate
()
==
ins
});
EXPECT
(
bool
{
p
.
validate
()
==
ins
});
}
}
MIGRAPH_ROB
(
access_ins_arguments
,
std
::
vector
<
migraph
::
instruction_ref
>
,
migraph
::
instruction
,
arguments
)
void
invalid_args
()
void
invalid_args
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -45,7 +51,7 @@ void invalid_args()
...
@@ -45,7 +51,7 @@ void invalid_args()
auto
one
=
p
.
add_literal
(
1
);
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
two
=
p
.
add_literal
(
2
);
auto
ins
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
ins
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
ins
->
arguments
.
clear
();
access_
ins
_
arguments
(
*
ins
)
.
clear
();
EXPECT
(
bool
{
p
.
validate
()
==
p
.
begin
()});
EXPECT
(
bool
{
p
.
validate
()
==
p
.
begin
()});
}
}
...
...
tools/include/operation.hpp
View file @
99ee76c0
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <migraph/shape.hpp>
#include <migraph/shape.hpp>
#include <migraph/rank.hpp>
#include <migraph/argument.hpp>
#include <migraph/argument.hpp>
#include <migraph/context.hpp>
#include <migraph/context.hpp>
#include <migraph/auto_any_cast.hpp>
#include <migraph/auto_any_cast.hpp>
...
@@ -27,13 +28,16 @@ struct operation
...
@@ -27,13 +28,16 @@ struct operation
/// exception.
/// exception.
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>&
input
)
const
;
/**
/**
* @brief This performs the operation's computation
* @brief This performs the operation's computation.
*
* This method can be optional when the operation is only used as a placeholder to be lowered
* later on.
*
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* `shape` of the `argument`.
* @param input This is the `argument` result from the previous instuction's computation.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same as the `output` shape.
* the same as the `output` shape.
*/
*/
...
@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
...
@@ -55,11 +59,29 @@ auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
}
// namespace operation_stream
}
// namespace operation_stream
// compute_op overload tagged rank<1>: forwards to the operation's own
// `compute` member. The trailing decltype removes this overload from
// consideration when `x.compute(auto_any_cast(ctx), output_shape, input)` is
// not a valid expression for T.
// NOTE(review): presumably rank<1> is preferred over rank<0> during overload
// resolution -- confirm against migraph/rank.hpp.
template <typename T>
auto compute_op(rank<1>,
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}
// compute_op overload tagged rank<0>: ignores the context, output shape and
// inputs, and raises an error through MIGRAPH_THROW whose message is
// "Not computable: " followed by the operation's name().
template <typename T>
argument compute_op(rank<0>,
                    const T& x,
                    context&,
                    const shape&,
                    const std::vector<argument>&)
{
    std::string op_name = x.name();
    MIGRAPH_THROW("Not computable: " + op_name);
}
template
<
class
T
>
template
<
class
T
>
argument
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
)
,
output_shape
,
input
);
return
compute
_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
}
<%
<%
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment