Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
31065c7d
Commit
31065c7d
authored
Oct 31, 2022
by
charlie
Browse files
Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
parents
6bec381f
6acbd4e4
Changes
482
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8287 additions
and
3043 deletions
+8287
-3043
src/api/migraphx.py
src/api/migraphx.py
+13
-0
src/apply_alpha_beta.cpp
src/apply_alpha_beta.cpp
+1
-1
src/common.cpp
src/common.cpp
+101
-14
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+94
-123
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+5122
-1888
src/driver/main.cpp
src/driver/main.cpp
+0
-2
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2811
-960
src/driver/verify.cpp
src/driver/verify.cpp
+2
-1
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+1
-1
src/file_buffer.cpp
src/file_buffer.cpp
+1
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+1
-1
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/argument.hpp
src/include/migraphx/argument.hpp
+1
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+3
-0
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+69
-3
src/include/migraphx/dyn_output.hpp
src/include/migraphx/dyn_output.hpp
+39
-20
src/include/migraphx/execution_environment.hpp
src/include/migraphx/execution_environment.hpp
+7
-8
No files found.
src/api/migraphx.py
View file @
31065c7d
...
@@ -115,6 +115,7 @@ def shape(h):
...
@@ -115,6 +115,7 @@ def shape(h):
const
=
True
)
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'strides'
,
returns
=
'const std::vector<size_t>&'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'type'
,
returns
=
'migraphx::shape::type_t'
,
const
=
True
)
h
.
method
(
'elements'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'bytes'
,
returns
=
'size_t'
,
const
=
True
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::shape&'
),
api
.
params
(
x
=
'const migraphx::shape&'
),
...
@@ -122,6 +123,7 @@ def shape(h):
...
@@ -122,6 +123,7 @@ def shape(h):
returns
=
'bool'
,
returns
=
'bool'
,
const
=
True
)
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'standard'
,
returns
=
'bool'
,
const
=
True
)
h
.
method
(
'index'
,
api
.
params
(
i
=
'size_t'
),
returns
=
'size_t'
,
const
=
True
)
@
auto_handle
()
@
auto_handle
()
...
@@ -274,6 +276,13 @@ def program(h):
...
@@ -274,6 +276,13 @@ def program(h):
params
=
'std::unordered_map<std::string, migraphx::argument>'
),
params
=
'std::unordered_map<std::string, migraphx::argument>'
),
invoke
=
'migraphx::run($@)'
,
invoke
=
'migraphx::run($@)'
,
returns
=
'std::vector<migraphx::argument>'
)
returns
=
'std::vector<migraphx::argument>'
)
h
.
method
(
'run_async'
,
api
.
params
(
params
=
'std::unordered_map<std::string, migraphx::argument>'
,
s
=
'void*'
,
name
=
'const char *'
),
invoke
=
'migraphx::run_async($@)'
,
returns
=
'std::vector<migraphx::argument>'
)
h
.
method
(
'equal'
,
h
.
method
(
'equal'
,
api
.
params
(
x
=
'const migraphx::program&'
),
api
.
params
(
x
=
'const migraphx::program&'
),
invoke
=
'migraphx::equal($@)'
,
invoke
=
'migraphx::equal($@)'
,
...
@@ -450,4 +459,8 @@ def experimental_custom_op(h):
...
@@ -450,4 +459,8 @@ def experimental_custom_op(h):
h
.
virtual
(
'compute_shape'
,
h
.
virtual
(
'compute_shape'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'migraphx::shape'
)
returns
=
'migraphx::shape'
)
h
.
virtual
(
'output_alias'
,
api
.
params
(
inputs
=
'std::vector<migraphx::shape>'
),
returns
=
'std::vector<size_t>'
)
h
.
virtual
(
'runs_on_offload_target'
,
returns
=
'bool'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
h
.
method
(
'register'
,
invoke
=
'migraphx::register_custom_op($@)'
)
src/apply_alpha_beta.cpp
View file @
31065c7d
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto
a
=
args
[
0
];
auto
a
=
args
[
0
];
auto
b
=
args
[
1
];
auto
b
=
args
[
1
];
auto
input_type
=
a
->
get_shape
().
type
();
auto
input_type
=
a
->
get_shape
().
type
();
if
(
!
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
if
(
not
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
{
{
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
...
...
src/common.cpp
View file @
31065c7d
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
// output_lens = (3,2,7,5)
//
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
std
::
vector
<
std
::
size_t
>
s1
)
{
{
...
@@ -50,25 +52,67 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -50,25 +52,67 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
s0
;
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
s0
.
swap
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
migraphx
::
to_string_range
(
s0
)
+
to_string_range
(
s1
)
+
"} mismatch!"
);
"} and {"
+
migraphx
::
to_string_range
(
s1
)
+
"} mismatch!"
);
}
}
return
std
::
max
(
a
,
b
);
return
std
::
max
(
a
,
b
);
});
});
return
out_lens
;
return
out_lens
;
}
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
assert
(
s0
.
dynamic
()
or
s1
.
dynamic
());
// change both shapes to dynamic_dimension representation
if
(
not
s0
.
dynamic
())
s0
=
s0
.
to_dynamic
();
if
(
not
s1
.
dynamic
())
s1
=
s1
.
to_dynamic
();
if
(
s0
.
ndim
()
>
s1
.
ndim
())
{
std
::
swap
(
s0
,
s1
);
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
shape
::
dynamic_dimension
one_dyn_dim
{
1
,
1
,
0
};
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s1
.
dyn_dims
().
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
(
a
==
one_dyn_dim
or
b
==
one_dyn_dim
)
{
// setting opt to 0, may need to be changed
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
}
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
.
dyn_dims
())
+
"} and {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
}
});
return
out_dims
;
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
assert
(
not
shapes
.
empty
());
assert
(
not
shapes
.
empty
());
assert
(
std
::
none_of
(
shapes
.
cbegin
(),
shapes
.
cend
(),
[](
auto
shape
)
{
return
shape
.
dynamic
();
}));
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
shapes
.
end
(),
shapes
.
end
(),
shapes
.
front
().
lens
(),
shapes
.
front
().
lens
(),
...
@@ -114,20 +158,63 @@ instruction_ref insert_common_op(module& m,
...
@@ -114,20 +158,63 @@ instruction_ref insert_common_op(module& m,
const
operation
&
op
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
)
std
::
vector
<
instruction_ref
>
inputs
)
{
{
auto
common
=
common_shape
(
to_shapes
(
inputs
));
if
(
std
::
any_of
(
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
->
get_shape
().
dynamic
();
}))
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
{
// currently only handles the binary case
if
(
inputs
.
size
()
!=
2
)
{
{
input
=
m
.
insert_instruction
(
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
common
.
lens
()}}),
input
);
"inputs, only handle two inputs if any are dynamic shape"
);
}
}
if
(
input
->
get_shape
().
type
()
!=
common
.
type
())
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
auto
c_dyn_dims
=
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
// following should work for a static or dynamic shape
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
input
=
m
.
insert_instruction
(
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
common
.
type
()}}),
input
);
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
0
],
inputs
[
1
]);
}
}
return
input
;
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
});
{
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
1
],
inputs
[
0
]);
}
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
c_type
}}),
input
);
}
return
input
;
});
}
else
{
auto
common
=
common_shape
(
to_shapes
(
inputs
));
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
common
.
lens
()}}),
input
);
}
if
(
input
->
get_shape
().
type
()
!=
common
.
type
())
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
common
.
type
()}}),
input
);
}
return
input
;
});
}
return
m
.
insert_instruction
(
ins
,
op
,
inputs
);
return
m
.
insert_instruction
(
ins
,
op
,
inputs
);
}
}
...
...
src/driver/alexnet.cpp
View file @
31065c7d
...
@@ -25,13 +25,10 @@
...
@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
#include "models.hpp"
namespace
migraphx
{
namespace
migraphx
{
namespace
driver
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
migraphx
::
program
alexnet
(
unsigned
batch
)
// NOLINT(readability-function-size)
migraphx
::
program
alexnet
(
unsigned
batch
)
// NOLINT(readability-function-size)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
1
)));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
1
)));
auto
x_main_module_2
=
mmain
->
add_literal
(
migraphx
::
abs
(
auto
x_main_module_2
=
mmain
->
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
2
)));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
2
)));
auto
x_
input_1
=
mmain
->
add_parameter
(
auto
x_
0
=
mmain
->
add_parameter
(
"
input.1
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
"
0
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
auto
x_main_module_4
=
mmain
->
add_literal
(
auto
x_main_module_4
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
3
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
3
));
auto
x_main_module_5
=
mmain
->
add_literal
(
auto
x_main_module_5
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
4
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
4
));
auto
x_main_module_6
=
mmain
->
add_literal
(
auto
x_main_module_6
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
5
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
5
));
auto
x_main_module_7
=
mmain
->
add_literal
(
auto
x_main_module_7
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
6
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
6
));
auto
x_main_module_8
=
mmain
->
add_literal
(
auto
x_main_module_8
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
7
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
7
));
auto
x_main_module_9
=
mmain
->
add_literal
(
auto
x_main_module_9
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
8
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
8
));
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
9
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
9
));
auto
x_main_module_11
=
mmain
->
add_literal
(
auto
x_main_module_11
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
10
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
10
));
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
11
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
11
));
auto
x_main_module_13
=
mmain
->
add_literal
(
auto
x_main_module_13
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
12
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
12
));
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
13
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
13
));
auto
x_main_module_15
=
mmain
->
add_literal
(
auto
x_main_module_15
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
14
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
14
));
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
15
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
15
));
auto
x_main_module_17
=
mmain
->
add_literal
(
auto
x_main_module_17
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
16
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
16
));
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
17
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
17
));
auto
x_main_module_19
=
mmain
->
add_literal
(
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
18
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
18
));
auto
x_main_module_20
=
mmain
->
add_instruction
(
auto
x_main_module_20
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"convolution"
,
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"
)),
x_0
,
x_input_1
,
x_main_module_18
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,64,55,55]}"
)),
x_main_module_19
);
x_main_module_19
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,64,55,55]}"
),
x_main_module_18
);
auto
x_main_module_22
=
auto
x_main_module_22
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_20
,
x_main_module_21
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_20
,
x_main_module_21
);
auto
x_main_module_23
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_22
);
auto
x_main_module_23
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_22
);
auto
x_main_module_24
=
mmain
->
add_instruction
(
auto
x_main_module_24
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_23
);
x_main_module_23
);
auto
x_main_module_25
=
mmain
->
add_instruction
(
auto
x_main_module_25
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"convolution"
,
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"
)),
x_main_module_24
,
x_main_module_24
,
x_main_module_1
4
);
x_main_module_1
7
);
auto
x_main_module_26
=
mmain
->
add_instruction
(
auto
x_main_module_26
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,192,27,27]}"
),
x_main_module_16
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,192,27,27]}"
)),
x_main_module_15
);
auto
x_main_module_27
=
auto
x_main_module_27
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_25
,
x_main_module_26
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_25
,
x_main_module_26
);
auto
x_main_module_28
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_27
);
auto
x_main_module_28
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_27
);
auto
x_main_module_29
=
mmain
->
add_instruction
(
auto
x_main_module_29
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_28
);
x_main_module_28
);
auto
x_main_module_30
=
mmain
->
add_instruction
(
auto
x_main_module_30
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"convolution"
,
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_29
,
x_main_module_29
,
x_main_module_1
2
);
x_main_module_1
5
);
auto
x_main_module_31
=
mmain
->
add_instruction
(
auto
x_main_module_31
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,384,13,13]}"
),
x_main_module_14
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,384,13,13]}"
)),
x_main_module_13
);
auto
x_main_module_32
=
auto
x_main_module_32
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_34
=
mmain
->
add_instruction
(
auto
x_main_module_34
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"convolution"
,
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_33
,
x_main_module_33
,
x_main_module_1
0
);
x_main_module_1
3
);
auto
x_main_module_35
=
mmain
->
add_instruction
(
auto
x_main_module_35
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_12
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_11
);
auto
x_main_module_36
=
auto
x_main_module_36
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_38
=
mmain
->
add_instruction
(
auto
x_main_module_38
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"convolution"
,
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
x_main_module_37
,
x_main_module_37
,
x_main_module_1
6
);
x_main_module_1
1
);
auto
x_main_module_39
=
mmain
->
add_instruction
(
auto
x_main_module_39
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_10
);
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_17
);
auto
x_main_module_40
=
auto
x_main_module_40
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_38
,
x_main_module_39
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_38
,
x_main_module_39
);
auto
x_main_module_41
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_40
);
auto
x_main_module_41
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_40
);
auto
x_main_module_42
=
mmain
->
add_instruction
(
auto
x_main_module_42
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
x_main_module_41
);
x_main_module_41
);
auto
x_main_module_43
=
mmain
->
add_instruction
(
auto
x_main_module_43
=
migraphx
::
make_op
(
"reshape"
,
migraphx
::
from_json_string
(
"{dims:[1,9216]}"
)),
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"flatten"
,
"{axis:1}"
),
x_main_module_42
);
x_main_module_42
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_43
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_9
);
x_main_module_6
);
auto
x_main_module_46
=
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_44
,
x_main_module_45
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_43
,
x_main_module_44
);
auto
x_main_module_46
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_7
);
auto
x_main_module_47
=
mmain
->
add_instruction
(
auto
x_main_module_47
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_8
);
x_main_module_2
);
auto
x_main_module_48
=
mmain
->
add_instruction
(
auto
x_main_module_48
=
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_2
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_46
,
x_main_module_47
);
auto
x_main_module_49
=
auto
x_main_module_49
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_45
,
x_main_module_48
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_47
,
x_main_module_48
);
auto
x_main_module_50
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_49
);
auto
x_main_module_50
=
auto
x_main_module_51
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_46
,
x_main_module_49
);
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
auto
x_main_module_51
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_50
);
x_main_module_4
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_51
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_50
,
x_main_module_51
);
auto
x_main_module_53
=
mmain
->
add_instruction
(
auto
x_main_module_53
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_7
);
x_main_module_5
);
auto
x_main_module_54
=
auto
x_main_module_54
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_52
,
x_main_module_53
);
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
auto
x_main_module_55
=
mmain
->
add_instruction
(
x_main_module_1
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_6
);
auto
x_main_module_55
=
auto
x_main_module_56
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_53
,
x_main_module_54
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_1
);
auto
x_main_module_56
=
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_52
,
x_main_module_55
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_55
,
x_main_module_56
);
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_56
);
auto
x_main_module_58
=
auto
x_main_module_58
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_54
,
x_main_module_57
);
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_58
);
x_main_module_8
);
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_57
,
x_main_module_58
);
auto
x_main_module_60
=
mmain
->
add_instruction
(
auto
x_main_module_60
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_5
);
x_main_module_9
);
auto
x_main_module_61
=
auto
x_main_module_61
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_59
,
x_main_module_60
);
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
auto
x_main_module_62
=
mmain
->
add_instruction
(
x_main_module_0
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_4
);
auto
x_main_module_62
=
auto
x_main_module_63
=
mmain
->
add_instruction
(
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_60
,
x_main_module_61
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_0
);
auto
x_main_module_63
=
auto
x_main_module_64
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_59
,
x_main_module_62
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_62
,
x_main_module_63
);
mmain
->
add_return
({
x_main_module_63
});
auto
x_main_module_65
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_61
,
x_main_module_64
);
mmain
->
add_return
({
x_main_module_65
});
return
p
;
return
p
;
}
}
...
...
src/driver/inceptionv3.cpp
View file @
31065c7d
This diff is collapsed.
Click to expand it.
src/driver/main.cpp
View file @
31065c7d
...
@@ -44,7 +44,6 @@
...
@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
...
@@ -221,7 +220,6 @@ struct loader
...
@@ -221,7 +220,6 @@ struct loader
{
{
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
{
{
migraphx
::
rewrite_batchnorm
{},
migraphx
::
eliminate_identity
{},
migraphx
::
eliminate_identity
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_algebra
{},
migraphx
::
simplify_algebra
{},
...
...
src/driver/resnet50.cpp
View file @
31065c7d
This diff is collapsed.
Click to expand it.
src/driver/verify.cpp
View file @
31065c7d
...
@@ -145,7 +145,7 @@ void verify_reduced(program p,
...
@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
}
}
...
@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
...
@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{
{
const
auto
*
mm
=
p
.
get_main_module
();
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
...
...
src/eliminate_concat.cpp
View file @
31065c7d
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
or
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
{
// Last input should be an allocation
// Last input should be an allocation
...
...
src/eliminate_contiguous.cpp
View file @
31065c7d
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
!
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
{
{
return
false
;
return
false
;
}
}
...
...
src/file_buffer.cpp
View file @
31065c7d
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
size
,
0
);
if
(
!
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
...
...
src/fuse_pointwise.cpp
View file @
31065c7d
...
@@ -39,7 +39,7 @@ static literal get_scalar(instruction_ref ins)
...
@@ -39,7 +39,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
not
(
s
.
elements
()
=
=
1
or
s
.
scalar
()))
if
(
s
.
elements
()
!
=
1
&&
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
...
...
src/include/migraphx/allocation_model.hpp
View file @
31065c7d
...
@@ -205,7 +205,7 @@ struct allocation_model
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -267,7 +267,7 @@ struct allocation_model
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/argument.hpp
View file @
31065c7d
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t
m_data
{};
data_t
m_data
{};
};
};
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
argument
>&
args
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
31065c7d
...
@@ -101,7 +101,7 @@ struct check_shapes
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -164,7 +164,7 @@ struct check_shapes
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_shape
()
const
const
check_shapes
&
same_shape
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -174,7 +174,7 @@ struct check_shapes
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_type
()
const
const
check_shapes
&
same_type
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -184,10 +184,10 @@ struct check_shapes
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_dims
()
const
const
check_shapes
&
same_dims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -197,7 +197,7 @@ struct check_shapes
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -207,7 +207,7 @@ struct check_shapes
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard
()
const
const
check_shapes
&
standard
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -217,7 +217,7 @@ struct check_shapes
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
standard_or_scalar
()
const
const
check_shapes
&
standard_or_scalar
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -227,7 +227,7 @@ struct check_shapes
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed
()
const
const
check_shapes
&
packed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -237,7 +237,7 @@ struct check_shapes
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
packed_or_broadcasted
()
const
const
check_shapes
&
packed_or_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -247,7 +247,7 @@ struct check_shapes
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
tuple_type
()
const
const
check_shapes
&
tuple_type
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -257,7 +257,7 @@ struct check_shapes
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_transposed
()
const
const
check_shapes
&
not_transposed
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -267,7 +267,7 @@ struct check_shapes
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
not_broadcasted
()
const
const
check_shapes
&
not_broadcasted
()
const
{
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -278,7 +278,7 @@ struct check_shapes
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
return
*
this
;
}
}
...
@@ -288,7 +288,8 @@ struct check_shapes
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
*/
const
check_shapes
&
batch_not_transposed
()
const
const
check_shapes
&
batch_not_transposed
()
const
{
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
return
*
this
;
}
}
...
...
src/include/migraphx/common.hpp
View file @
31065c7d
...
@@ -36,6 +36,9 @@ struct operation;
...
@@ -36,6 +36,9 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/concat_opt.hpp
View file @
31065c7d
...
@@ -183,7 +183,7 @@ struct concat_optimization
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
:
private_detail_te_value
(
std
::
move
(
value
))
{
{
...
@@ -233,7 +233,7 @@ struct concat_optimization
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/include/migraphx/context.hpp
View file @
31065c7d
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
{
return
{};
return
{};
}
}
template
<
class
T
>
void
wait_for_context
(
T
&
,
any_ptr
)
{
}
template
<
class
T
>
void
finish_on_context
(
T
&
,
any_ptr
)
{
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
...
@@ -78,6 +87,10 @@ struct context
...
@@ -78,6 +87,10 @@ struct context
void
from_value
(
const
value
&
v
);
void
from_value
(
const
value
&
v
);
// (optional)
// (optional)
any_ptr
get_queue
();
any_ptr
get_queue
();
// (optional)
void
wait_for
(
any_ptr
queue
);
// (optional)
void
finish_on
(
any_ptr
queue
);
//
//
void
finish
()
const
;
void
finish
()
const
;
};
};
...
@@ -165,6 +178,18 @@ struct context
...
@@ -165,6 +178,18 @@ struct context
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
}
}
void
wait_for
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
queue
);
}
void
finish_on
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
queue
);
}
void
finish
()
const
void
finish
()
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -187,6 +212,8 @@ struct context
...
@@ -187,6 +212,8 @@ struct context
virtual
value
to_value
()
const
=
0
;
virtual
value
to_value
()
const
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
void
wait_for
(
any_ptr
queue
)
=
0
;
virtual
void
finish_on
(
any_ptr
queue
)
=
0
;
virtual
void
finish
()
const
=
0
;
virtual
void
finish
()
const
=
0
;
};
};
...
@@ -231,6 +258,33 @@ struct context
...
@@ -231,6 +258,33 @@ struct context
return
get_queue_context
(
private_detail_te_self
);
return
get_queue_context
(
private_detail_te_self
);
}
}
template
<
class
T
>
static
auto
private_detail_te_default_wait_for
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
wait_for
(
queue
))
{
private_detail_te_self
.
wait_for
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_wait_for
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
wait_for_context
(
private_detail_te_self
,
queue
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finish_on
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
finish_on
(
queue
))
{
private_detail_te_self
.
finish_on
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_finish_on
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
finish_on_context
(
private_detail_te_self
,
queue
);
}
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
{
...
@@ -246,9 +300,9 @@ struct context
...
@@ -246,9 +300,9 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
)
)
:
private_detail_te_value
(
value
)
{
{
}
}
...
@@ -277,6 +331,18 @@ struct context
...
@@ -277,6 +331,18 @@ struct context
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
}
}
void
wait_for
(
any_ptr
queue
)
override
{
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish_on
(
any_ptr
queue
)
override
{
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
...
@@ -306,7 +372,7 @@ struct context
...
@@ -306,7 +372,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
return
*
private_detail_te_handle_mem_var
;
}
}
...
...
src/
targets/gpu/device/conca
t.
c
pp
→
src/
include/migraphx/dyn_outpu
t.
h
pp
View file @
31065c7d
...
@@ -21,36 +21,55 @@
...
@@ -21,36 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
concat
(
hipStream_t
stream
,
struct
dyn_output
const
migraphx
::
shape
&
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
vector
<
std
::
size_t
>
offsets
)
{
{
auto
ninputs
=
args
.
size
()
-
1
;
// original shape from the instruction
for
(
std
::
size_t
j
=
0
;
j
<
ninputs
;
j
++
)
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
()
const
{
{
auto
&&
arg
=
args
[
j
];
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
auto
offset
=
offsets
[
j
];
if
(
ins_shape
.
dynamic
())
auto
byte_offset
=
offset
*
arg
.
get_shape
().
type_size
();
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
auto
output_shape
=
shape
{
return
dyn_output
{
ins_shape
,
ins_shape
};
arg
.
get_shape
().
type
(),
arg
.
get_shape
().
lens
(),
args
.
back
().
get_shape
().
strides
()};
});
auto
output
=
argument
{
output_shape
,
args
.
back
().
data
()
+
byte_offset
};
contiguous
(
stream
,
output
,
arg
);
}
}
return
args
.
back
();
operator
shape
()
const
{
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
}
};
template
<
class
F
>
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
{
return
{
f
};
}
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
src/
targets/gpu/
include/migraphx/
gpu/acosh
.hpp
→
src/include/migraphx/
execution_environment
.hpp
View file @
31065c7d
...
@@ -21,22 +21,21 @@
...
@@ -21,22 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_ACOSH
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_ACOSH
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/any_ptr.hpp>
#include <migraphx/gpu/device/acosh.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
hip_acosh
:
unary_device
<
hip_acosh
,
device
::
acosh
>
struct
execution_environment
{
{
any_ptr
queue
=
any_ptr
{};
bool
async
=
false
;
};
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
/* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
Prev
1
2
3
4
5
6
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment