Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
870a396b
Commit
870a396b
authored
Jan 23, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
228b665c
d309e02f
Changes
473
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8064 additions
and
3874 deletions
+8064
-3874
src/common.cpp
src/common.cpp
+96
-14
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+2
-2
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+15
-15
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+5003
-2465
src/driver/main.cpp
src/driver/main.cpp
+8
-4
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+2744
-1313
src/driver/verify.cpp
src/driver/verify.cpp
+2
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+17
-4
src/file_buffer.cpp
src/file_buffer.cpp
+16
-8
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+11
-2
src/include/migraphx/argument.hpp
src/include/migraphx/argument.hpp
+1
-0
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+1
-1
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+3
-0
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+67
-1
src/include/migraphx/dyn_output.hpp
src/include/migraphx/dyn_output.hpp
+39
-20
src/include/migraphx/execution_environment.hpp
src/include/migraphx/execution_environment.hpp
+7
-8
src/include/migraphx/file_buffer.hpp
src/include/migraphx/file_buffer.hpp
+1
-1
src/include/migraphx/instruction.hpp
src/include/migraphx/instruction.hpp
+2
-0
src/include/migraphx/layout_nhwc.hpp
src/include/migraphx/layout_nhwc.hpp
+23
-0
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+6
-15
No files found.
src/common.cpp
View file @
870a396b
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
// output_lens = (3,2,7,5)
//
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
std
::
vector
<
std
::
size_t
>
s1
)
{
{
...
@@ -50,25 +52,62 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -50,25 +52,62 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
s0
;
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
s0
.
swap
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
migraphx
::
to_string_range
(
s0
)
+
to_string_range
(
s1
)
+
"} mismatch!"
);
"} and {"
+
migraphx
::
to_string_range
(
s1
)
+
"} mismatch!"
);
}
}
return
std
::
max
(
a
,
b
);
return
std
::
max
(
a
,
b
);
});
});
return
out_lens
;
return
out_lens
;
}
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
// change both shapes to dynamic_dimension representation
s0
=
s0
.
to_dynamic
();
s1
=
s1
.
to_dynamic
();
if
(
s0
.
ndim
()
>
s1
.
ndim
())
{
std
::
swap
(
s0
,
s1
);
}
auto
offset
=
s1
.
ndim
()
-
s0
.
ndim
();
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
std
::
transform
(
s0
.
dyn_dims
().
cbegin
(),
s0
.
dyn_dims
().
cend
(),
s1
.
dyn_dims
().
cbegin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
(
a
==
1
or
b
==
1
)
{
// setting opt to 0, may need to be changed
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
0
};
}
else
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
.
dyn_dims
())
+
"} and {"
+
migraphx
::
to_string_range
(
s1
.
dyn_dims
())
+
"} mismatch!"
);
}
});
return
out_dims
;
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
assert
(
not
shapes
.
empty
());
assert
(
not
shapes
.
empty
());
assert
(
std
::
none_of
(
shapes
.
cbegin
(),
shapes
.
cend
(),
[](
auto
shape
)
{
return
shape
.
dynamic
();
}));
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
return
transform_accumulate
(
shapes
.
begin
()
+
1
,
shapes
.
end
(),
shapes
.
end
(),
shapes
.
front
().
lens
(),
shapes
.
front
().
lens
(),
...
@@ -114,20 +153,63 @@ instruction_ref insert_common_op(module& m,
...
@@ -114,20 +153,63 @@ instruction_ref insert_common_op(module& m,
const
operation
&
op
,
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
inputs
)
std
::
vector
<
instruction_ref
>
inputs
)
{
{
auto
common
=
common_shape
(
to_shapes
(
inputs
));
if
(
std
::
any_of
(
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
inputs
.
cbegin
(),
inputs
.
cend
(),
[](
auto
input
)
{
return
input
->
get_shape
().
dynamic
();
}))
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
{
// currently only handles the binary case
if
(
inputs
.
size
()
!=
2
)
{
{
input
=
m
.
insert_instruction
(
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
common
.
lens
()}}),
input
);
"inputs, only handle two inputs if any are dynamic shape"
);
}
}
if
(
input
->
get_shape
().
type
()
!=
common
.
type
())
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
auto
c_dyn_dims
=
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
// following should work for a static or dynamic shape
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
input
=
m
.
insert_instruction
(
inputs
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
common
.
type
()}}),
input
);
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
0
],
inputs
[
1
]);
}
}
return
input
;
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
});
{
inputs
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)}}),
inputs
[
1
],
inputs
[
0
]);
}
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
c_type
}}),
input
);
}
return
input
;
});
}
else
{
auto
common
=
common_shape
(
to_shapes
(
inputs
));
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
common
.
lens
()}}),
input
);
}
if
(
input
->
get_shape
().
type
()
!=
common
.
type
())
{
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
common
.
type
()}}),
input
);
}
return
input
;
});
}
return
m
.
insert_instruction
(
ins
,
op
,
inputs
);
return
m
.
insert_instruction
(
ins
,
op
,
inputs
);
}
}
...
...
src/dead_code_elimination.cpp
View file @
870a396b
...
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
...
@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
// identity, allocate]
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
i
->
name
().
front
()
!
=
'@'
and
not
(
i
->
name
().
front
()
=
=
'@'
)
and
not
contains
({
"identity"
,
"allocate"
},
i
->
name
())
and
not
contains
({
"undefined"
,
"identity"
,
"allocate"
},
i
->
name
()
))
not
i
->
is_undefined
(
))
continue
;
continue
;
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
std
::
unordered_set
<
instruction_ref
>
visited
;
std
::
unordered_set
<
instruction_ref
>
visited
;
...
...
src/driver/alexnet.cpp
View file @
870a396b
...
@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -74,9 +74,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
18
));
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
18
));
auto
x_main_module_20
=
mmain
->
add_instruction
(
auto
x_main_module_20
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"convolution"
,
migraphx
::
make_json_op
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
"convolution"
,
"4],use_dynamic_same_auto_pad:0
}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]
}"
),
x_0
,
x_0
,
x_main_module_19
);
x_main_module_19
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
auto
x_main_module_21
=
mmain
->
add_instruction
(
...
@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -90,9 +90,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
x_main_module_23
);
x_main_module_23
);
auto
x_main_module_25
=
mmain
->
add_instruction
(
auto
x_main_module_25
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"convolution"
,
migraphx
::
make_json_op
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
"convolution"
,
"1],use_dynamic_same_auto_pad:0
}"
),
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]
}"
),
x_main_module_24
,
x_main_module_24
,
x_main_module_17
);
x_main_module_17
);
auto
x_main_module_26
=
mmain
->
add_instruction
(
auto
x_main_module_26
=
mmain
->
add_instruction
(
...
@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -106,9 +106,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
x_main_module_28
);
x_main_module_28
);
auto
x_main_module_30
=
mmain
->
add_instruction
(
auto
x_main_module_30
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"convolution"
,
migraphx
::
make_json_op
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"convolution"
,
"1],use_dynamic_same_auto_pad:0
}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]
}"
),
x_main_module_29
,
x_main_module_29
,
x_main_module_15
);
x_main_module_15
);
auto
x_main_module_31
=
mmain
->
add_instruction
(
auto
x_main_module_31
=
mmain
->
add_instruction
(
...
@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -117,9 +117,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_34
=
mmain
->
add_instruction
(
auto
x_main_module_34
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"convolution"
,
migraphx
::
make_json_op
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"convolution"
,
"1],use_dynamic_same_auto_pad:0
}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]
}"
),
x_main_module_33
,
x_main_module_33
,
x_main_module_13
);
x_main_module_13
);
auto
x_main_module_35
=
mmain
->
add_instruction
(
auto
x_main_module_35
=
mmain
->
add_instruction
(
...
@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
...
@@ -128,9 +128,9 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_38
=
mmain
->
add_instruction
(
auto
x_main_module_38
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"convolution"
,
migraphx
::
make_json_op
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"convolution"
,
"1],use_dynamic_same_auto_pad:0
}"
),
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]
}"
),
x_main_module_37
,
x_main_module_37
,
x_main_module_11
);
x_main_module_11
);
auto
x_main_module_39
=
mmain
->
add_instruction
(
auto
x_main_module_39
=
mmain
->
add_instruction
(
...
...
src/driver/inceptionv3.cpp
View file @
870a396b
This diff is collapsed.
Click to expand it.
src/driver/main.cpp
View file @
870a396b
...
@@ -44,7 +44,6 @@
...
@@ -44,7 +44,6 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
...
@@ -110,8 +109,12 @@ struct loader
...
@@ -110,8 +109,12 @@ struct loader
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
brief
,
{
"--brief"
},
ap
.
help
(
"Make the output brief."
),
ap
.
set_value
(
true
));
ap
(
output_type
,
ap
(
output_type
,
{
"--cpp"
},
{
"--cpp"
},
ap
.
help
(
"Print out the program as
cpp
program."
),
ap
.
help
(
"Print out the program as
C++
program."
),
ap
.
set_value
(
"cpp"
));
ap
.
set_value
(
"cpp"
));
ap
(
output_type
,
{
"--python"
,
"--py"
},
ap
.
help
(
"Print out the program as python program."
),
ap
.
set_value
(
"py"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
{
"--json"
},
ap
.
help
(
"Print out program as json."
),
ap
.
set_value
(
"json"
));
ap
(
output_type
,
ap
(
output_type
,
{
"--text"
},
{
"--text"
},
...
@@ -221,7 +224,6 @@ struct loader
...
@@ -221,7 +224,6 @@ struct loader
{
{
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
{
{
migraphx
::
rewrite_batchnorm
{},
migraphx
::
eliminate_identity
{},
migraphx
::
eliminate_identity
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_algebra
{},
migraphx
::
simplify_algebra
{},
...
@@ -261,7 +263,9 @@ struct loader
...
@@ -261,7 +263,9 @@ struct loader
type
=
"binary"
;
type
=
"binary"
;
}
}
if
(
type
==
"cpp"
)
if
(
type
==
"py"
)
p
.
print_py
(
*
os
);
else
if
(
type
==
"cpp"
)
p
.
print_cpp
(
*
os
);
p
.
print_cpp
(
*
os
);
else
if
(
type
==
"graphviz"
)
else
if
(
type
==
"graphviz"
)
p
.
print_graph
(
*
os
,
brief
);
p
.
print_graph
(
*
os
,
brief
);
...
...
src/driver/resnet50.cpp
View file @
870a396b
This diff is collapsed.
Click to expand it.
src/driver/verify.cpp
View file @
870a396b
...
@@ -145,7 +145,7 @@ void verify_reduced(program p,
...
@@ -145,7 +145,7 @@ void verify_reduced(program p,
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
auto
last
=
std
::
prev
(
mm
->
end
(),
n
+
1
);
mm
->
remove_instructions
(
last
,
mm
->
end
());
mm
->
remove_instructions
(
last
,
mm
->
end
());
std
::
cout
<<
"Verify: "
<<
std
::
endl
;
std
::
cout
<<
"Verify: "
<<
n
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
std
::
cout
<<
p
<<
std
::
endl
;
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_program
(
std
::
to_string
(
n
),
p
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
}
}
...
@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
...
@@ -159,6 +159,7 @@ void verify_reduced_program(const program& p,
{
{
const
auto
*
mm
=
p
.
get_main_module
();
const
auto
*
mm
=
p
.
get_main_module
();
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
auto
n
=
std
::
distance
(
mm
->
begin
(),
mm
->
end
());
std
::
cout
<<
"Verify steps: "
<<
n
<<
std
::
endl
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
{
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
verify_reduced
(
p
,
i
,
t
,
options
,
quantize
,
inputs
,
tolerance
);
...
...
src/eliminate_contiguous.cpp
View file @
870a396b
...
@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try
try
{
{
shape
new_shape
=
ins
->
get_operator
().
compute_shape
(
inputs
,
mods
);
shape
new_shape
=
ins
->
get_operator
().
compute_shape
(
inputs
,
mods
);
// Cannot tell if a dynamic shape will need to be made contiguous
if
(
new_shape
.
dynamic
())
{
return
false
;
}
// If the output shape is a standard shape, no need to try its output
// If the output shape is a standard shape, no need to try its output
if
(
new_shape
.
standard
())
if
(
new_shape
.
standard
())
{
{
...
@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
}
}
// Perform evaluations in parallel
// Perform
static contiguous
evaluations in parallel
std
::
vector
<
argument
>
literals
(
const_instructions
.
size
());
std
::
vector
<
argument
>
literals
(
const_instructions
.
size
());
par_for
(
const_instructions
.
size
(),
1
,
[
&
](
const
auto
i
)
{
par_for
(
const_instructions
.
size
(),
1
,
[
&
](
const
auto
i
)
{
auto
c
=
op
::
contiguous
{};
auto
c
=
op
::
contiguous
{};
auto
prev
=
const_instructions
[
i
]
->
inputs
().
front
();
auto
prev
=
const_instructions
[
i
]
->
inputs
().
front
();
literals
[
i
]
=
c
.
compute
(
c
.
compute_shape
({
prev
->
get_shape
()}),
{
prev
->
eval
()});
// compute the output contiguous shape from the previous instruction shape
shape
computed_shape
=
c
.
compute_shape
({
prev
->
get_shape
()});
const
std
::
vector
<
argument
>&
prev_eval
=
{
prev
->
eval
()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto
co_shape
=
make_compute_output_shape
(
pack
(
c
,
computed_shape
,
prev_eval
));
literals
[
i
]
=
c
.
compute
(
co_shape
,
prev_eval
);
});
});
// Replace static contiguous operations with a literal
for
(
size_t
i
=
0
;
i
<
const_instructions
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
const_instructions
.
size
();
i
++
)
{
{
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
auto
l
=
m
.
add_literal
(
literals
[
i
].
get_shape
(),
literals
[
i
].
data
());
...
...
src/file_buffer.cpp
View file @
870a396b
...
@@ -30,23 +30,31 @@ namespace migraphx {
...
@@ -30,23 +30,31 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
T
>
template
<
class
T
>
T
generic_read_file
(
const
std
::
string
&
filename
)
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
{
{
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
ifstream
is
(
filename
,
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
streamsize
size
=
is
.
tellg
();
if
(
nbytes
==
0
)
if
(
size
<
1
)
{
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes
=
is
.
tellg
();
if
(
offset
>
nbytes
)
MIGRAPHX_THROW
(
"offset is larger than file size"
);
nbytes
-=
offset
;
}
if
(
nbytes
<
1
)
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
MIGRAPHX_THROW
(
"Invalid size for: "
+
filename
);
is
.
seekg
(
0
,
std
::
ios
::
beg
);
is
.
seekg
(
offset
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
T
buffer
(
nbytes
,
0
);
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
nbytes
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
return
buffer
;
}
}
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
)
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
,
size_t
nbytes
)
{
{
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
);
return
generic_read_file
<
std
::
vector
<
char
>>
(
filename
,
offset
,
nbytes
);
}
}
std
::
string
read_string
(
const
std
::
string
&
filename
)
std
::
string
read_string
(
const
std
::
string
&
filename
)
...
...
src/fuse_pointwise.cpp
View file @
870a396b
...
@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
...
@@ -39,13 +39,22 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
not
(
s
.
elements
()
=
=
1
or
s
.
scalar
()))
if
(
s
.
elements
()
!
=
1
&&
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
auto
e
=
ins
->
eval
();
auto
e
=
ins
->
eval
();
literal
r
{};
literal
r
{};
e
.
visit_at
([
&
](
auto
x
)
{
r
=
literal
{
x
};
});
// needed for bool as visit_at invokes as() which promotes bool to int8
// Without this we'll break type checks for logical ops that are fused.
if
(
e
.
get_shape
().
type
()
==
shape
::
bool_type
)
{
r
=
literal
{
e
.
at
<
bool
>
()};
}
else
{
e
.
visit_at
([
&
](
auto
x
)
{
r
=
literal
{
x
};
});
}
return
r
;
return
r
;
}
}
...
...
src/include/migraphx/argument.hpp
View file @
870a396b
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
...
@@ -107,6 +107,7 @@ struct argument : raw_data<argument>
data_t
m_data
{};
data_t
m_data
{};
};
};
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
argument
>&
args
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_to_value
(
value
&
v
,
const
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
void
migraphx_from_value
(
const
value
&
v
,
argument
&
a
);
...
...
src/include/migraphx/check_shapes.hpp
View file @
870a396b
...
@@ -198,7 +198,7 @@ struct check_shapes
...
@@ -198,7 +198,7 @@ struct check_shapes
*/
*/
const
check_shapes
&
same_ndims
()
const
const
check_shapes
&
same_ndims
()
const
{
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
ndim
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
return
*
this
;
}
}
...
...
src/include/migraphx/common.hpp
View file @
870a396b
...
@@ -36,6 +36,9 @@ struct operation;
...
@@ -36,6 +36,9 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
instruction_ref
insert_common_op
(
module
&
m
,
instruction_ref
insert_common_op
(
module
&
m
,
...
...
src/include/migraphx/context.hpp
View file @
870a396b
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
{
return
{};
return
{};
}
}
template
<
class
T
>
void
wait_for_context
(
T
&
,
any_ptr
)
{
}
template
<
class
T
>
void
finish_on_context
(
T
&
,
any_ptr
)
{
}
#ifdef TYPE_ERASED_DECLARATION
#ifdef TYPE_ERASED_DECLARATION
...
@@ -78,6 +87,10 @@ struct context
...
@@ -78,6 +87,10 @@ struct context
void
from_value
(
const
value
&
v
);
void
from_value
(
const
value
&
v
);
// (optional)
// (optional)
any_ptr
get_queue
();
any_ptr
get_queue
();
// (optional)
void
wait_for
(
any_ptr
queue
);
// (optional)
void
finish_on
(
any_ptr
queue
);
//
//
void
finish
()
const
;
void
finish
()
const
;
};
};
...
@@ -165,6 +178,18 @@ struct context
...
@@ -165,6 +178,18 @@ struct context
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
}
}
void
wait_for
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
queue
);
}
void
finish_on
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
queue
);
}
void
finish
()
const
void
finish
()
const
{
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
@@ -187,6 +212,8 @@ struct context
...
@@ -187,6 +212,8 @@ struct context
virtual
value
to_value
()
const
=
0
;
virtual
value
to_value
()
const
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
void
wait_for
(
any_ptr
queue
)
=
0
;
virtual
void
finish_on
(
any_ptr
queue
)
=
0
;
virtual
void
finish
()
const
=
0
;
virtual
void
finish
()
const
=
0
;
};
};
...
@@ -231,6 +258,33 @@ struct context
...
@@ -231,6 +258,33 @@ struct context
return
get_queue_context
(
private_detail_te_self
);
return
get_queue_context
(
private_detail_te_self
);
}
}
template
<
class
T
>
static
auto
private_detail_te_default_wait_for
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
wait_for
(
queue
))
{
private_detail_te_self
.
wait_for
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_wait_for
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
wait_for_context
(
private_detail_te_self
,
queue
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finish_on
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
finish_on
(
queue
))
{
private_detail_te_self
.
finish_on
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_finish_on
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
finish_on_context
(
private_detail_te_self
,
queue
);
}
template
<
typename
PrivateDetailTypeErasedT
>
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
{
...
@@ -248,7 +302,7 @@ struct context
...
@@ -248,7 +302,7 @@ struct context
PrivateDetailTypeErasedT
value
,
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
)
)
:
private_detail_te_value
(
value
)
{
{
}
}
...
@@ -277,6 +331,18 @@ struct context
...
@@ -277,6 +331,18 @@ struct context
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
}
}
void
wait_for
(
any_ptr
queue
)
override
{
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish_on
(
any_ptr
queue
)
override
{
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
PrivateDetailTypeErasedT
private_detail_te_value
;
...
...
src/
targets/gpu/device/conca
t.
c
pp
→
src/
include/migraphx/dyn_outpu
t.
h
pp
View file @
870a396b
...
@@ -21,36 +21,55 @@
...
@@ -21,36 +21,55 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_DYN_OUTPUT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
concat
(
hipStream_t
stream
,
struct
dyn_output
const
migraphx
::
shape
&
,
std
::
vector
<
migraphx
::
argument
>
args
,
std
::
vector
<
std
::
size_t
>
offsets
)
{
{
auto
ninputs
=
args
.
size
()
-
1
;
// original shape from the instruction
for
(
std
::
size_t
j
=
0
;
j
<
ninputs
;
j
++
)
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
* If converted to shape type, returns original ins_shape.
* If converted to dyn_output type, will compute an output shape using the input arguments.
*/
template
<
class
F
>
struct
compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
()
const
{
{
auto
&&
arg
=
args
[
j
];
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
auto
offset
=
offsets
[
j
];
if
(
ins_shape
.
dynamic
())
auto
byte_offset
=
offset
*
arg
.
get_shape
().
type_size
();
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
auto
output_shape
=
shape
{
return
dyn_output
{
ins_shape
,
ins_shape
};
arg
.
get_shape
().
type
(),
arg
.
get_shape
().
lens
(),
args
.
back
().
get_shape
().
strides
()};
});
auto
output
=
argument
{
output_shape
,
args
.
back
().
data
()
+
byte_offset
};
contiguous
(
stream
,
output
,
arg
);
}
}
return
args
.
back
();
operator
shape
()
const
{
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
}
};
template
<
class
F
>
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
{
return
{
f
};
}
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
src/
targets/gpu/
include/migraphx/
gpu/add
.hpp
→
src/include/migraphx/
execution_environment
.hpp
View file @
870a396b
...
@@ -21,22 +21,21 @@
...
@@ -21,22 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_ADD
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_ADD
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/any_ptr.hpp>
#include <migraphx/gpu/device/add.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
hip_add
:
binary_device
<
hip_add
,
device
::
add
>
struct
execution_environment
{
{
any_ptr
queue
=
any_ptr
{};
bool
async
=
false
;
};
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
/* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
src/include/migraphx/file_buffer.hpp
View file @
870a396b
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
);
std
::
vector
<
char
>
read_buffer
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
std
::
string
read_string
(
const
std
::
string
&
filename
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
void
write_buffer
(
const
std
::
string
&
filename
,
const
char
*
buffer
,
std
::
size_t
size
);
...
...
src/include/migraphx/instruction.hpp
View file @
870a396b
...
@@ -121,6 +121,8 @@ struct instruction
...
@@ -121,6 +121,8 @@ struct instruction
bool
can_eval
()
const
;
bool
can_eval
()
const
;
bool
is_undefined
()
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
argument
eval
(
bool
check_eval
=
true
)
const
;
void
finalize
(
context
&
ctx
);
void
finalize
(
context
&
ctx
);
...
...
src/include/migraphx/layout_nhwc.hpp
View file @
870a396b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
...
...
src/include/migraphx/literal.hpp
View file @
870a396b
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
...
@@ -80,6 +80,7 @@ struct literal : raw_data<literal>
fill
(
start
,
end
);
fill
(
start
,
end
);
}
}
// Directly copies buffer of x
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
sizeof
(
T
)
==
1
)>
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
literal
(
const
shape
&
s
,
T
*
x
)
:
buffer
(
make_shared_array
<
char
>
(
s
.
bytes
())),
m_shape
(
s
)
{
{
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
...
@@ -107,25 +108,15 @@ struct literal : raw_data<literal>
std
::
shared_ptr
<
char
>
buffer
;
std
::
shared_ptr
<
char
>
buffer
;
shape
m_shape
;
shape
m_shape
;
// Keeps the same data ordering as the given container
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
m_shape
.
visit_type
([
&
](
auto
as
)
{
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
std
::
copy
(
start
,
end
,
output
.
begin
());
}
});
else
{
auto
it
=
start
;
m_shape
.
visit_type
([
&
](
auto
as
)
{
auto
output
=
make_view
(
m_shape
,
as
.
from
(
buffer
.
get
()));
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
*
it
;
// NOLINT(bugprone-signed-char-misuse)
it
++
;
});
});
}
}
}
};
};
...
...
Prev
1
2
3
4
5
6
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment