Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f97b8ef
Unverified
Commit
7f97b8ef
authored
Oct 07, 2022
by
Ted Themistokleous
Committed by
GitHub
Oct 07, 2022
Browse files
Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check
parents
2ba401f0
d1fed367
Changes
448
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4508 additions
and
2913 deletions
+4508
-2913
src/apply_alpha_beta.cpp
src/apply_alpha_beta.cpp
+1
-1
src/auto_contiguous.cpp
src/auto_contiguous.cpp
+1
-1
src/dead_code_elimination.cpp
src/dead_code_elimination.cpp
+4
-3
src/driver/alexnet.cpp
src/driver/alexnet.cpp
+99
-128
src/driver/argument_parser.hpp
src/driver/argument_parser.hpp
+398
-24
src/driver/command.hpp
src/driver/command.hpp
+9
-3
src/driver/inceptionv3.cpp
src/driver/inceptionv3.cpp
+2468
-1772
src/driver/main.cpp
src/driver/main.cpp
+51
-11
src/driver/resnet50.cpp
src/driver/resnet50.cpp
+1355
-935
src/eliminate_concat.cpp
src/eliminate_concat.cpp
+1
-1
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+1
-1
src/file_buffer.cpp
src/file_buffer.cpp
+1
-1
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+16
-0
src/include/migraphx/allocation_model.hpp
src/include/migraphx/allocation_model.hpp
+2
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+16
-15
src/include/migraphx/concat_opt.hpp
src/include/migraphx/concat_opt.hpp
+2
-2
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+69
-3
src/include/migraphx/execution_environment.hpp
src/include/migraphx/execution_environment.hpp
+7
-8
src/include/migraphx/iterator.hpp
src/include/migraphx/iterator.hpp
+2
-2
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+5
-0
No files found.
src/apply_alpha_beta.cpp
View file @
7f97b8ef
...
...
@@ -39,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto
a
=
args
[
0
];
auto
b
=
args
[
1
];
auto
input_type
=
a
->
get_shape
().
type
();
if
(
!
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
if
(
not
float_equal
(
alpha
.
at
<
float
>
(
0
),
1.0
))
{
auto
alpha_literal
=
m
.
add_literal
(
alpha
);
a
=
insert_common_op
(
m
,
pos
,
migraphx
::
make_op
(
"mul"
),
{
alpha_literal
,
a
});
...
...
src/auto_contiguous.cpp
View file @
7f97b8ef
...
...
@@ -63,7 +63,7 @@ void auto_contiguous::apply(module& m) const
if
(
ins
->
outputs
().
empty
()
and
ins
!=
last
)
continue
;
shape
s
=
ins
->
get_shape
();
if
(
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
if
(
not
s
.
dynamic
()
and
not
s
.
standard
()
and
s
.
elements
()
!=
0
)
{
auto
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"contiguous"
),
ins
);
m
.
replace_instruction
(
ins
,
c
);
...
...
src/dead_code_elimination.cpp
View file @
7f97b8ef
...
...
@@ -48,9 +48,10 @@ void dead_code_elimination::apply(module& m) const
// Skip the last instruction
if
(
i
==
last
)
break
;
// Skip instruction with empty shape as output unless its a builtin, undefined, identity, or
// allocate
if
(
i
->
get_shape
().
elements
()
==
0
and
i
->
name
().
front
()
!=
'@'
and
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
if
((
not
i
->
get_shape
().
dynamic
()
and
i
->
get_shape
().
elements
()
==
0
)
and
i
->
name
().
front
()
!=
'@'
and
not
contains
({
"undefined"
,
"identity"
,
"allocate"
},
i
->
name
()))
continue
;
assert
(
std
::
distance
(
m
.
begin
(),
i
)
<=
std
::
distance
(
m
.
begin
(),
last
));
...
...
src/driver/alexnet.cpp
View file @
7f97b8ef
...
...
@@ -25,13 +25,10 @@
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/json.hpp>
#include "models.hpp"
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
migraphx
::
program
alexnet
(
unsigned
batch
)
// NOLINT(readability-function-size)
{
migraphx
::
program
p
;
...
...
@@ -42,179 +39,153 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
1
)));
auto
x_main_module_2
=
mmain
->
add_literal
(
migraphx
::
abs
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
2
)));
auto
x_
input_1
=
mmain
->
add_parameter
(
"
input.1
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
auto
x_
0
=
mmain
->
add_parameter
(
"
0
"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch
,
3
,
224
,
224
}});
auto
x_main_module_4
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
3
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
3
));
auto
x_main_module_5
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
4
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
4
));
auto
x_main_module_6
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
5
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
5
));
auto
x_main_module_7
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
6
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
4096
}},
6
));
auto
x_main_module_8
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
,
4096
}},
7
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
}},
7
));
auto
x_main_module_9
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1000
}},
8
));
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
9
));
auto
x_main_module_11
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
10
));
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
11
));
auto
x_main_module_13
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
12
));
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
13
));
auto
x_main_module_15
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
14
));
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
15
));
auto
x_main_module_17
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
16
));
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
17
));
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
18
));
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4096
,
9216
}},
8
));
auto
x_main_module_10
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
9
));
auto
x_main_module_11
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
256
,
3
,
3
}},
10
));
auto
x_main_module_12
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
}},
11
));
auto
x_main_module_13
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
256
,
384
,
3
,
3
}},
12
));
auto
x_main_module_14
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
}},
13
));
auto
x_main_module_15
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
384
,
192
,
3
,
3
}},
14
));
auto
x_main_module_16
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
}},
15
));
auto
x_main_module_17
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
192
,
64
,
5
,
5
}},
16
));
auto
x_main_module_18
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
}},
17
));
auto
x_main_module_19
=
mmain
->
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
64
,
3
,
11
,
11
}},
18
));
auto
x_main_module_20
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,4]}"
)),
x_input_1
,
x_main_module_18
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,64,55,55]}"
)),
migraphx
::
make_json_op
(
"convolution"
,
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
"4],use_dynamic_same_auto_pad:0}"
),
x_0
,
x_main_module_19
);
auto
x_main_module_21
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,64,55,55]}"
),
x_main_module_18
);
auto
x_main_module_22
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_20
,
x_main_module_21
);
auto
x_main_module_23
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_22
);
auto
x_main_module_24
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
x_main_module_23
);
auto
x_main_module_25
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,1]}"
)),
migraphx
::
make_json_op
(
"convolution"
,
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"
),
x_main_module_24
,
x_main_module_1
4
);
x_main_module_1
7
);
auto
x_main_module_26
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,192,27,27]}"
)),
x_main_module_15
);
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,192,27,27]}"
),
x_main_module_16
);
auto
x_main_module_27
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_25
,
x_main_module_26
);
auto
x_main_module_28
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_27
);
auto
x_main_module_29
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
x_main_module_28
);
auto
x_main_module_30
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
migraphx
::
make_json_op
(
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"
),
x_main_module_29
,
x_main_module_1
2
);
x_main_module_1
5
);
auto
x_main_module_31
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,384,13,13]}"
)),
x_main_module_13
);
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,384,13,13]}"
),
x_main_module_14
);
auto
x_main_module_32
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_30
,
x_main_module_31
);
auto
x_main_module_33
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_32
);
auto
x_main_module_34
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
migraphx
::
make_json_op
(
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"
),
x_main_module_33
,
x_main_module_1
0
);
x_main_module_1
3
);
auto
x_main_module_35
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_11
);
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_12
);
auto
x_main_module_36
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_34
,
x_main_module_35
);
auto
x_main_module_37
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_36
);
auto
x_main_module_38
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
migraphx
::
from_json_string
(
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,1]}"
)),
migraphx
::
make_json_op
(
"convolution"
,
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"
),
x_main_module_37
,
x_main_module_1
6
);
x_main_module_1
1
);
auto
x_main_module_39
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
migraphx
::
from_json_string
(
"{axis:1,out_lens:[1,256,13,13]}"
)),
x_main_module_17
);
migraphx
::
make_json_op
(
"broadcast"
,
"{axis:1,out_lens:[1,256,13,13]}"
),
x_main_module_10
);
auto
x_main_module_40
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_38
,
x_main_module_39
);
auto
x_main_module_41
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_40
);
auto
x_main_module_42
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_
json_
op
(
"pooling"
,
migraphx
::
from_json_string
(
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
)),
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"
),
x_main_module_41
);
auto
x_main_module_43
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
migraphx
::
from_json_string
(
"{dims:[1,9216]}"
)),
x_main_module_42
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
x_main_module_6
);
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_43
,
x_main_module_44
);
auto
x_main_module_46
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_7
);
auto
x_main_module_43
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"flatten"
,
"{axis:1}"
),
x_main_module_42
);
auto
x_main_module_44
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_43
);
auto
x_main_module_45
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_9
);
auto
x_main_module_46
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_44
,
x_main_module_45
);
auto
x_main_module_47
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_2
);
auto
x_main_module_48
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_46
,
x_main_module_47
);
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_8
);
auto
x_main_module_48
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_2
);
auto
x_main_module_49
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_45
,
x_main_module_48
);
auto
x_main_module_50
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_49
);
auto
x_main_module_51
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
x_main_module_4
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_50
,
x_main_module_51
);
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_47
,
x_main_module_48
);
auto
x_main_module_50
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_46
,
x_main_module_49
);
auto
x_main_module_51
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_50
);
auto
x_main_module_52
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_main_module_51
);
auto
x_main_module_53
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_5
);
auto
x_main_module_54
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,4096]}"
)),
x_main_module_1
);
auto
x_main_module_55
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_53
,
x_main_module_54
);
auto
x_main_module_56
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_52
,
x_main_module_55
);
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_56
);
auto
x_main_module_58
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
migraphx
::
from_json_string
(
"{permutation:[1,0]}"
)),
x_main_module_8
);
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_57
,
x_main_module_58
);
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_7
);
auto
x_main_module_54
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_52
,
x_main_module_53
);
auto
x_main_module_55
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_6
);
auto
x_main_module_56
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,4096]}"
),
x_main_module_1
);
auto
x_main_module_57
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_55
,
x_main_module_56
);
auto
x_main_module_58
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_54
,
x_main_module_57
);
auto
x_main_module_59
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"relu"
),
x_main_module_58
);
auto
x_main_module_60
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
x_main_module_9
);
auto
x_main_module_61
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
migraphx
::
from_json_string
(
"{out_lens:[1,1000]}"
)),
x_main_module_0
);
auto
x_main_module_62
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_60
,
x_main_module_61
);
auto
x_main_module_63
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_59
,
x_main_module_62
);
mmain
->
add_return
({
x_main_module_63
});
migraphx
::
make_json_op
(
"transpose"
,
"{permutation:[1,0]}"
),
x_main_module_5
);
auto
x_main_module_61
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
x_main_module_59
,
x_main_module_60
);
auto
x_main_module_62
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_4
);
auto
x_main_module_63
=
mmain
->
add_instruction
(
migraphx
::
make_json_op
(
"multibroadcast"
,
"{out_lens:[1,1000]}"
),
x_main_module_0
);
auto
x_main_module_64
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_main_module_62
,
x_main_module_63
);
auto
x_main_module_65
=
mmain
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_main_module_61
,
x_main_module_64
);
mmain
->
add_return
({
x_main_module_65
});
return
p
;
}
...
...
src/driver/argument_parser.hpp
View file @
7f97b8ef
...
...
@@ -27,11 +27,13 @@
#include <algorithm>
#include <functional>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
...
...
@@ -39,9 +41,16 @@
#include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/rank.hpp>
#ifndef _WIN32
#include <unistd.h>
#endif
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -74,6 +83,65 @@ template <class T>
using
is_multi_value
=
std
::
integral_constant
<
bool
,
(
is_container
<
T
>
{}
and
not
std
::
is_convertible
<
T
,
std
::
string
>
{})
>
;
enum
class
color
{
reset
=
0
,
bold
=
1
,
underlined
=
4
,
fg_red
=
31
,
fg_green
=
32
,
fg_yellow
=
33
,
fg_blue
=
34
,
fg_default
=
39
,
bg_red
=
41
,
bg_green
=
42
,
bg_yellow
=
43
,
bg_blue
=
44
,
bg_default
=
49
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
color
&
c
)
{
#ifndef _WIN32
static
const
bool
use_color
=
isatty
(
STDOUT_FILENO
)
!=
0
;
if
(
use_color
)
return
os
<<
"
\033
["
<<
static_cast
<
std
::
size_t
>
(
c
)
<<
"m"
;
#endif
return
os
;
}
inline
std
::
string
colorize
(
color
c
,
const
std
::
string
&
s
)
{
std
::
stringstream
ss
;
ss
<<
c
<<
s
<<
color
::
reset
;
return
ss
.
str
();
}
template
<
class
T
>
struct
type_name
{
static
const
std
::
string
&
apply
()
{
return
migraphx
::
get_type_name
<
T
>
();
}
};
template
<
>
struct
type_name
<
std
::
string
>
{
static
const
std
::
string
&
apply
()
{
static
const
std
::
string
name
=
"std::string"
;
return
name
;
}
};
template
<
class
T
>
struct
type_name
<
std
::
vector
<
T
>>
{
static
const
std
::
string
&
apply
()
{
static
const
std
::
string
name
=
"std::vector<"
+
type_name
<
T
>::
apply
()
+
">"
;
return
name
;
}
};
template
<
class
T
>
struct
value_parser
{
...
...
@@ -85,7 +153,7 @@ struct value_parser
ss
.
str
(
x
);
ss
>>
result
;
if
(
ss
.
fail
())
throw
std
::
runtime_error
(
"Failed to parse
:
"
+
x
);
throw
std
::
runtime_error
(
"Failed to parse
'
"
+
x
+
"' as "
+
type_name
<
T
>::
apply
()
);
return
result
;
}
...
...
@@ -97,7 +165,7 @@ struct value_parser
ss
.
str
(
x
);
ss
>>
i
;
if
(
ss
.
fail
())
throw
std
::
runtime_error
(
"Failed to parse
:
"
+
x
);
throw
std
::
runtime_error
(
"Failed to parse
'
"
+
x
+
"' as "
+
type_name
<
T
>::
apply
()
);
return
static_cast
<
T
>
(
i
);
}
...
...
@@ -115,13 +183,42 @@ struct argument_parser
{
struct
argument
{
using
action_function
=
std
::
function
<
bool
(
argument_parser
&
,
const
std
::
vector
<
std
::
string
>&
)
>
;
using
validate_function
=
std
::
function
<
void
(
const
argument_parser
&
,
const
std
::
vector
<
std
::
string
>&
)
>
;
std
::
vector
<
std
::
string
>
flags
;
std
::
function
<
bool
(
argument_parser
&
,
const
std
::
vector
<
std
::
string
>&
)
>
action
{};
action_function
action
{};
std
::
string
type
=
""
;
std
::
string
help
=
""
;
std
::
string
metavar
=
""
;
std
::
string
default_value
=
""
;
std
::
string
group
=
""
;
unsigned
nargs
=
1
;
bool
required
=
false
;
std
::
vector
<
validate_function
>
validations
{};
std
::
string
usage
(
const
std
::
string
&
flag
)
const
{
std
::
stringstream
ss
;
if
(
flag
.
empty
())
{
ss
<<
metavar
;
}
else
{
ss
<<
flag
;
if
(
not
type
.
empty
())
ss
<<
" ["
<<
type
<<
"]"
;
}
return
ss
.
str
();
}
std
::
string
usage
()
const
{
if
(
flags
.
empty
())
return
usage
(
""
);
return
usage
(
flags
.
front
());
}
};
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_multi_value
<
T
>{})
>
...
...
@@ -154,12 +251,14 @@ struct argument_parser
arguments
.
push_back
({
flags
,
[
&
](
auto
&&
,
const
std
::
vector
<
std
::
string
>&
params
)
{
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"Flag with no value."
);
if
(
not
is_multi_value
<
T
>
{}
and
params
.
size
()
>
1
)
throw
std
::
runtime_error
(
"Too many arguments passed."
);
x
=
value_parser
<
T
>::
apply
(
params
.
back
());
return
false
;
}});
argument
&
arg
=
arguments
.
back
();
arg
.
type
=
migraphx
::
get_
type_name
<
T
>
();
arg
.
type
=
type_name
<
T
>
::
apply
();
migraphx
::
each_args
([
&
](
auto
f
)
{
f
(
x
,
arg
);
},
fs
...);
if
(
not
arg
.
default_value
.
empty
()
and
arg
.
nargs
>
0
)
arg
.
default_value
=
as_string_value
(
x
);
...
...
@@ -181,6 +280,11 @@ struct argument_parser
return
[
=
](
auto
&&
,
auto
&
arg
)
{
arg
.
nargs
=
n
;
};
}
MIGRAPHX_DRIVER_STATIC
auto
required
()
{
return
[
=
](
auto
&&
,
auto
&
arg
)
{
arg
.
required
=
true
;
};
}
template
<
class
F
>
MIGRAPHX_DRIVER_STATIC
auto
write_action
(
F
f
)
{
...
...
@@ -215,34 +319,164 @@ struct argument_parser
});
}
MIGRAPHX_DRIVER_STATIC
auto
show_help
(
const
std
::
string
&
msg
=
""
)
template
<
class
F
>
MIGRAPHX_DRIVER_STATIC
auto
validate
(
F
f
)
{
return
[
=
](
const
auto
&
x
,
auto
&
arg
)
{
arg
.
validations
.
push_back
(
[
&
,
f
](
auto
&
self
,
const
std
::
vector
<
std
::
string
>&
params
)
{
f
(
self
,
x
,
params
);
});
};
}
MIGRAPHX_DRIVER_STATIC
auto
file_exist
()
{
return
validate
([](
auto
&
,
auto
&
,
auto
&
params
)
{
if
(
params
.
empty
())
throw
std
::
runtime_error
(
"No argument passed."
);
if
(
not
fs
::
exists
(
params
.
back
()))
throw
std
::
runtime_error
(
"Path does not exists: "
+
params
.
back
());
});
}
template
<
class
F
>
argument
*
find_argument
(
F
f
)
{
auto
it
=
std
::
find_if
(
arguments
.
begin
(),
arguments
.
end
(),
f
);
if
(
it
==
arguments
.
end
())
return
nullptr
;
return
std
::
addressof
(
*
it
);
}
template
<
class
F
>
bool
has_argument
(
F
f
)
{
return
find_argument
(
f
)
!=
nullptr
;
}
template
<
class
F
>
std
::
vector
<
argument
*>
find_arguments
(
F
f
)
{
std
::
vector
<
argument
*>
result
;
for
(
auto
&
arg
:
arguments
)
{
if
(
not
f
(
arg
))
continue
;
result
.
push_back
(
&
arg
);
}
return
result
;
}
std
::
vector
<
argument
*>
get_group_arguments
(
const
std
::
string
&
group
)
{
return
find_arguments
([
&
](
const
auto
&
arg
)
{
return
arg
.
group
==
group
;
});
}
std
::
vector
<
argument
*>
get_required_arguments
()
{
return
find_arguments
([
&
](
const
auto
&
arg
)
{
return
arg
.
required
;
});
}
template
<
class
SequenceContainer
>
std
::
vector
<
std
::
string
>
get_argument_usages
(
SequenceContainer
args
)
{
std
::
vector
<
std
::
string
>
usage_flags
;
std
::
unordered_set
<
std
::
string
>
found_groups
;
// Remove arguments that belong to a group
auto
it
=
std
::
remove_if
(
args
.
begin
(),
args
.
end
(),
[
&
](
const
argument
*
arg
)
{
if
(
arg
->
group
.
empty
())
return
false
;
found_groups
.
insert
(
arg
->
group
);
return
true
;
});
args
.
erase
(
it
,
args
.
end
());
transform
(
found_groups
,
std
::
back_inserter
(
usage_flags
),
[
&
](
auto
&&
group
)
{
std
::
vector
<
std
::
string
>
either_flags
;
transform
(
get_group_arguments
(
group
),
std
::
back_inserter
(
either_flags
),
[](
auto
*
arg
)
{
return
arg
->
usage
();
});
return
"("
+
join_strings
(
either_flags
,
"|"
)
+
")"
;
});
transform
(
args
,
std
::
back_inserter
(
usage_flags
),
[
&
](
auto
*
arg
)
{
return
arg
->
usage
();
});
return
usage_flags
;
}
auto
show_help
(
const
std
::
string
&
msg
=
""
)
{
return
do_action
([
=
](
auto
&
self
)
{
for
(
auto
&&
arg
:
self
.
arguments
)
argument
*
input_argument
=
self
.
find_argument
([](
const
auto
&
arg
)
{
return
arg
.
flags
.
empty
();
});
auto
required_usages
=
get_argument_usages
(
get_required_arguments
());
if
(
required_usages
.
empty
()
&&
input_argument
)
required_usages
.
push_back
(
input_argument
->
metavar
);
required_usages
.
insert
(
required_usages
.
begin
(),
"<options>"
);
print_usage
(
required_usages
);
std
::
cout
<<
std
::
endl
;
if
(
self
.
find_argument
([](
const
auto
&
arg
)
{
return
arg
.
nargs
==
0
;
}))
{
std
::
cout
<<
color
::
fg_yellow
<<
"FLAGS:"
<<
color
::
reset
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
std
::
string
prefix
=
" "
;
if
(
arg
.
flags
.
empty
())
{
std
::
cout
<<
prefix
;
std
::
cout
<<
arg
.
metavar
;
}
for
(
const
std
::
string
&
a
:
arg
.
flags
)
for
(
auto
&&
arg
:
self
.
arguments
)
{
std
::
cout
<<
prefix
;
std
::
cout
<<
a
;
prefix
=
", "
;
if
(
arg
.
nargs
!=
0
)
continue
;
const
int
col_align
=
35
;
std
::
string
prefix
=
" "
;
int
len
=
0
;
std
::
cout
<<
color
::
fg_green
;
for
(
const
std
::
string
&
a
:
arg
.
flags
)
{
len
+=
prefix
.
length
()
+
a
.
length
();
std
::
cout
<<
prefix
;
std
::
cout
<<
a
;
prefix
=
", "
;
}
std
::
cout
<<
color
::
reset
;
int
spaces
=
col_align
-
len
;
if
(
spaces
<
0
)
{
std
::
cout
<<
std
::
endl
;
}
else
{
for
(
int
i
=
0
;
i
<
spaces
;
i
++
)
std
::
cout
<<
" "
;
}
std
::
cout
<<
arg
.
help
<<
std
::
endl
;
}
if
(
not
arg
.
type
.
empty
())
std
::
cout
<<
std
::
endl
;
}
if
(
self
.
find_argument
([](
const
auto
&
arg
)
{
return
arg
.
nargs
!=
0
;
}))
{
std
::
cout
<<
color
::
fg_yellow
<<
"OPTIONS:"
<<
color
::
reset
<<
std
::
endl
;
for
(
auto
&&
arg
:
self
.
arguments
)
{
std
::
cout
<<
" ["
<<
arg
.
type
<<
"]"
;
if
(
not
arg
.
default_value
.
empty
())
std
::
cout
<<
" (Default: "
<<
arg
.
default_value
<<
")"
;
if
(
arg
.
nargs
==
0
)
continue
;
std
::
cout
<<
std
::
endl
;
std
::
string
prefix
=
" "
;
std
::
cout
<<
color
::
fg_green
;
if
(
arg
.
flags
.
empty
())
{
std
::
cout
<<
prefix
;
std
::
cout
<<
arg
.
metavar
;
}
for
(
const
std
::
string
&
a
:
arg
.
flags
)
{
std
::
cout
<<
prefix
;
std
::
cout
<<
a
;
prefix
=
", "
;
}
std
::
cout
<<
color
::
reset
;
if
(
not
arg
.
type
.
empty
())
{
std
::
cout
<<
" ["
<<
color
::
fg_blue
<<
arg
.
type
<<
color
::
reset
<<
"]"
;
if
(
not
arg
.
default_value
.
empty
())
std
::
cout
<<
" (Default: "
<<
arg
.
default_value
<<
")"
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
" "
<<
arg
.
help
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
" "
<<
arg
.
help
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
if
(
not
msg
.
empty
())
std
::
cout
<<
msg
<<
std
::
endl
;
});
...
...
@@ -263,6 +497,11 @@ struct argument_parser
return
[
=
](
auto
&
,
auto
&
arg
)
{
arg
.
type
=
type
;
};
}
MIGRAPHX_DRIVER_STATIC
auto
group
(
const
std
::
string
&
group
)
{
return
[
=
](
auto
&
,
auto
&
arg
)
{
arg
.
group
=
group
;
};
}
template
<
class
T
>
MIGRAPHX_DRIVER_STATIC
auto
set_value
(
T
value
)
{
...
...
@@ -276,6 +515,109 @@ struct argument_parser
};
}
template
<
class
T
>
void
set_exe_name_to
(
T
&
x
)
{
actions
.
push_back
([
&
](
const
auto
&
self
)
{
x
=
self
.
exe_name
;
});
}
void
print_try_help
()
{
if
(
has_argument
([](
const
auto
&
a
)
{
return
contains
(
a
.
flags
,
"--help"
);
}))
{
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"For more information try '"
<<
color
::
fg_green
<<
"--help"
<<
color
::
reset
<<
"'"
<<
std
::
endl
;
}
}
void
print_usage
(
const
std
::
vector
<
std
::
string
>&
flags
)
const
{
std
::
cout
<<
color
::
fg_yellow
<<
"USAGE:"
<<
color
::
reset
<<
std
::
endl
;
std
::
cout
<<
" "
<<
exe_name
<<
" "
;
std
::
cout
<<
join_strings
(
flags
,
" "
)
<<
std
::
endl
;
}
auto
spellcheck
(
const
std
::
vector
<
std
::
string
>&
inputs
)
{
struct
result_t
{
const
argument
*
arg
=
nullptr
;
std
::
string
correct
=
""
;
std
::
string
incorrect
=
""
;
std
::
ptrdiff_t
distance
=
std
::
numeric_limits
<
std
::
ptrdiff_t
>::
max
();
};
result_t
result
;
for
(
const
auto
&
input
:
inputs
)
{
if
(
input
.
empty
())
continue
;
if
(
input
[
0
]
!=
'-'
)
continue
;
for
(
const
auto
&
arg
:
arguments
)
{
for
(
const
auto
&
flag
:
arg
.
flags
)
{
if
(
flag
.
empty
())
continue
;
if
(
flag
[
0
]
!=
'-'
)
continue
;
auto
d
=
levenshtein_distance
(
flag
.
begin
(),
flag
.
end
(),
input
.
begin
(),
input
.
end
());
if
(
d
<
result
.
distance
)
result
=
result_t
{
&
arg
,
flag
,
input
,
d
};
}
}
}
return
result
;
}
bool
run_action
(
const
argument
&
arg
,
const
std
::
string
&
flag
,
const
std
::
vector
<
std
::
string
>&
inputs
)
{
std
::
string
msg
=
""
;
try
{
for
(
const
auto
&
v
:
arg
.
validations
)
v
(
*
this
,
inputs
);
return
arg
.
action
(
*
this
,
inputs
);
}
catch
(
const
std
::
exception
&
e
)
{
msg
=
e
.
what
();
}
catch
(...)
{
msg
=
"unknown exception"
;
}
std
::
cout
<<
color
::
fg_red
<<
color
::
bold
<<
"error: "
<<
color
::
reset
;
auto
sc
=
spellcheck
(
inputs
);
if
(
sc
.
distance
<
5
)
{
std
::
cout
<<
"Found argument '"
<<
color
::
fg_yellow
<<
sc
.
incorrect
<<
color
::
reset
<<
"'"
<<
" which wasn't expected, or isn't valid in this context"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
"Did you mean "
<<
color
::
fg_green
<<
sc
.
correct
<<
color
::
reset
<<
"?"
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
print_usage
({
sc
.
arg
->
usage
(
sc
.
correct
)});
}
else
{
const
auto
&
flag_name
=
flag
.
empty
()
?
arg
.
metavar
:
flag
;
std
::
cout
<<
"Invalid input to '"
<<
color
::
fg_yellow
;
std
::
cout
<<
arg
.
usage
(
flag_name
);
std
::
cout
<<
color
::
reset
<<
"'"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
msg
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
print_usage
({
arg
.
usage
()});
}
std
::
cout
<<
std
::
endl
;
print_try_help
();
return
true
;
}
bool
parse
(
std
::
vector
<
std
::
string
>
args
)
{
std
::
unordered_map
<
std
::
string
,
unsigned
>
keywords
;
...
...
@@ -286,8 +628,11 @@ struct argument_parser
}
auto
arg_map
=
generic_parse
(
std
::
move
(
args
),
[
&
](
const
std
::
string
&
x
)
{
return
keywords
[
x
];
});
std
::
list
<
const
argument
*>
missing_arguments
;
std
::
unordered_set
<
std
::
string
>
groups_used
;
for
(
auto
&&
arg
:
arguments
)
{
bool
used
=
false
;
auto
flags
=
arg
.
flags
;
if
(
flags
.
empty
())
flags
=
{
""
};
...
...
@@ -295,14 +640,41 @@ struct argument_parser
{
if
(
arg_map
.
count
(
flag
)
>
0
)
{
if
(
arg
.
action
(
*
this
,
arg_map
[
flag
]))
if
(
run_
action
(
arg
,
flag
,
arg_map
[
flag
]))
return
true
;
used
=
true
;
}
}
if
(
used
and
not
arg
.
group
.
empty
())
groups_used
.
insert
(
arg
.
group
);
if
(
arg
.
required
and
not
used
)
missing_arguments
.
push_back
(
&
arg
);
}
// Remove arguments from a group that is being used
missing_arguments
.
remove_if
(
[
&
](
const
argument
*
arg
)
{
return
groups_used
.
count
(
arg
->
group
);
});
if
(
not
missing_arguments
.
empty
())
{
std
::
cout
<<
color
::
fg_red
<<
color
::
bold
<<
"error: "
<<
color
::
reset
;
std
::
cout
<<
"The following required arguments were not provided:"
<<
std
::
endl
;
std
::
cout
<<
" "
<<
color
::
fg_red
<<
join_strings
(
get_argument_usages
(
std
::
move
(
missing_arguments
)),
" "
)
<<
color
::
reset
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
auto
required_usages
=
get_argument_usages
(
get_required_arguments
());
print_usage
(
required_usages
);
print_try_help
();
return
true
;
}
for
(
auto
&&
action
:
actions
)
action
(
*
this
);
return
false
;
}
void
set_exe_name
(
const
std
::
string
&
s
)
{
exe_name
=
s
;
}
const
std
::
string
&
get_exe_name
()
const
{
return
exe_name
;
}
using
string_map
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
template
<
class
IsKeyword
>
static
string_map
generic_parse
(
std
::
vector
<
std
::
string
>
as
,
IsKeyword
is_keyword
)
...
...
@@ -337,7 +709,9 @@ struct argument_parser
}
private:
std
::
vector
<
argument
>
arguments
;
std
::
list
<
argument
>
arguments
;
std
::
string
exe_name
=
""
;
std
::
vector
<
std
::
function
<
void
(
argument_parser
&
)
>>
actions
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/driver/command.hpp
View file @
7f97b8ef
...
...
@@ -41,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline
auto
&
get_commands
()
{
// NOLINTNEXTLINE
static
std
::
unordered_map
<
std
::
string
,
std
::
function
<
void
(
std
::
vector
<
std
::
string
>
args
)
>>
m
;
static
std
::
unordered_map
<
std
::
string
,
std
::
function
<
void
(
const
std
::
string
&
exe_name
,
std
::
vector
<
std
::
string
>
args
)
>>
m
;
return
m
;
}
...
...
@@ -65,10 +68,11 @@ const std::string& command_name()
}
template
<
class
T
>
void
run_command
(
std
::
vector
<
std
::
string
>
args
,
bool
add_help
=
false
)
void
run_command
(
const
std
::
string
&
exe_name
,
std
::
vector
<
std
::
string
>
args
,
bool
add_help
=
false
)
{
T
x
;
argument_parser
ap
;
ap
.
set_exe_name
(
exe_name
+
" "
+
command_name
<
T
>
());
if
(
add_help
)
ap
(
nullptr
,
{
"-h"
,
"--help"
},
ap
.
help
(
"Show help"
),
ap
.
show_help
());
x
.
parse
(
ap
);
...
...
@@ -81,7 +85,9 @@ template <class T>
int
auto_register_command
()
{
auto
&
m
=
get_commands
();
m
[
command_name
<
T
>
()]
=
[](
std
::
vector
<
std
::
string
>
args
)
{
run_command
<
T
>
(
args
,
true
);
};
m
[
command_name
<
T
>
()]
=
[](
const
std
::
string
&
exe_name
,
std
::
vector
<
std
::
string
>
args
)
{
run_command
<
T
>
(
exe_name
,
args
,
true
);
};
return
0
;
}
...
...
src/driver/inceptionv3.cpp
View file @
7f97b8ef
This diff is collapsed.
Click to expand it.
src/driver/main.cpp
View file @
7f97b8ef
...
...
@@ -73,8 +73,12 @@ struct loader
void
parse
(
argument_parser
&
ap
)
{
ap
(
file
,
{},
ap
.
metavar
(
"<input file>"
));
ap
(
model
,
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
));
ap
(
file
,
{},
ap
.
metavar
(
"<input file>"
),
ap
.
file_exist
(),
ap
.
required
(),
ap
.
group
(
"input"
));
ap
(
model
,
{
"--model"
},
ap
.
help
(
"Load model"
),
ap
.
type
(
"resnet50|inceptionv3|alexnet"
),
ap
.
group
(
"input"
));
ap
(
file_type
,
{
"--onnx"
},
ap
.
help
(
"Load as onnx"
),
ap
.
set_value
(
"onnx"
));
ap
(
file_type
,
{
"--tf"
},
ap
.
help
(
"Load as tensorflow"
),
ap
.
set_value
(
"tf"
));
ap
(
file_type
,
{
"--migraphx"
},
ap
.
help
(
"Load as MIGraphX"
),
ap
.
set_value
(
"migraphx"
));
...
...
@@ -578,26 +582,62 @@ struct onnx : command<onnx>
struct
main_command
{
static
std
::
string
get_command_help
()
static
std
::
string
get_command_help
(
const
std
::
string
&
title
=
colorize
(
color
::
fg_yellow
,
"COMMANDS:"
))
{
std
::
string
result
=
"Commands:
\n
"
;
return
std
::
accumulate
(
get_commands
().
begin
(),
get_commands
().
end
(),
result
,
[](
auto
r
,
auto
&&
p
)
{
return
r
+
" "
+
p
.
first
+
"
\n
"
;
});
std
::
string
result
=
title
+
"
\n
"
;
std
::
vector
<
std
::
string
>
commands
(
get_commands
().
size
());
std
::
transform
(
get_commands
().
begin
(),
get_commands
().
end
(),
commands
.
begin
(),
[](
const
auto
&
p
)
{
return
colorize
(
color
::
fg_green
,
p
.
first
);
});
std
::
sort
(
commands
.
begin
(),
commands
.
end
());
return
std
::
accumulate
(
commands
.
begin
(),
commands
.
end
(),
result
,
[](
auto
r
,
auto
&&
s
)
{
return
r
+
" "
+
s
+
"
\n
"
;
});
}
void
parse
(
argument_parser
&
ap
)
{
std
::
string
version_str
=
"MIGraphX Version: "
+
std
::
to_string
(
MIGRAPHX_VERSION_MAJOR
)
+
"."
+
std
::
to_string
(
MIGRAPHX_VERSION_MINOR
);
ap
(
wrong_commands
,
{},
ap
.
metavar
(
"<command>"
),
ap
.
append
());
ap
(
nullptr
,
{
"-h"
,
"--help"
},
ap
.
help
(
"Show help"
),
ap
.
show_help
(
get_command_help
()));
ap
(
nullptr
,
{
"-v"
,
"--version"
},
ap
.
help
(
"Show MIGraphX version"
),
ap
.
show_help
(
version_str
));
// Trim command off of exe name
ap
.
set_exe_name
(
ap
.
get_exe_name
().
substr
(
0
,
ap
.
get_exe_name
().
size
()
-
5
));
ap
.
set_exe_name_to
(
exe_name
);
}
void
run
()
{}
std
::
vector
<
std
::
string
>
wrong_commands
{};
std
::
string
exe_name
=
"<exe>"
;
void
run
()
{
std
::
cout
<<
color
::
fg_red
<<
color
::
bold
<<
"error: "
<<
color
::
reset
;
auto
it
=
std
::
find_if
(
wrong_commands
.
begin
(),
wrong_commands
.
end
(),
[](
const
auto
&
c
)
{
return
get_commands
().
count
(
c
)
>
0
;
});
if
(
it
==
wrong_commands
.
end
())
{
std
::
cout
<<
"'"
<<
color
::
fg_yellow
<<
wrong_commands
.
front
()
<<
color
::
reset
<<
"' is not a valid command."
<<
std
::
endl
;
std
::
cout
<<
get_command_help
(
"Available commands:"
)
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"command '"
<<
color
::
fg_yellow
<<
*
it
<<
color
::
reset
<<
"' must be first argument"
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
color
::
fg_yellow
<<
"USAGE:"
<<
color
::
reset
<<
std
::
endl
;
std
::
cout
<<
" "
<<
exe_name
<<
" "
<<
*
it
<<
" <options>"
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
}
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
@@ -619,11 +659,11 @@ int main(int argc, const char* argv[])
auto
cmd
=
args
.
front
();
if
(
m
.
count
(
cmd
)
>
0
)
{
m
.
at
(
cmd
)({
args
.
begin
()
+
1
,
args
.
end
()});
m
.
at
(
cmd
)(
argv
[
0
],
{
args
.
begin
()
+
1
,
args
.
end
()});
}
else
{
run_command
<
main_command
>
(
args
);
run_command
<
main_command
>
(
argv
[
0
],
args
);
}
return
0
;
...
...
src/driver/resnet50.cpp
View file @
7f97b8ef
This diff is collapsed.
Click to expand it.
src/eliminate_concat.cpp
View file @
7f97b8ef
...
...
@@ -60,7 +60,7 @@ void eliminate_concat::apply(module& m) const
auto
lens
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
concat_op
=
concat_opt
.
get_concat
(
ins
->
get_operator
());
std
::
size_t
axis_index
=
tune_axis
(
lens
.
size
(),
concat_op
.
axis
,
concat_op
.
name
());
if
(
axis_index
==
0
||
if
(
axis_index
==
0
or
std
::
all_of
(
lens
.
begin
(),
lens
.
begin
()
+
axis_index
,
[](
auto
x
)
{
return
x
==
1
;
}))
{
// Last input should be an allocation
...
...
src/eliminate_contiguous.cpp
View file @
7f97b8ef
...
...
@@ -71,7 +71,7 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
if
(
!
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
{
return
false
;
}
...
...
src/file_buffer.cpp
View file @
7f97b8ef
...
...
@@ -39,7 +39,7 @@ T generic_read_file(const std::string& filename)
is
.
seekg
(
0
,
std
::
ios
::
beg
);
T
buffer
(
size
,
0
);
if
(
!
is
.
read
(
&
buffer
[
0
],
size
))
if
(
not
is
.
read
(
&
buffer
[
0
],
size
))
MIGRAPHX_THROW
(
"Error reading file: "
+
filename
);
return
buffer
;
}
...
...
src/include/migraphx/algorithm.hpp
View file @
7f97b8ef
...
...
@@ -74,6 +74,22 @@ void group_unique(Iterator start, Iterator last, Output out, Predicate pred)
}
}
template
<
class
Iterator1
,
class
Iterator2
>
std
::
ptrdiff_t
levenshtein_distance
(
Iterator1
first1
,
Iterator1
last1
,
Iterator2
first2
,
Iterator2
last2
)
{
if
(
first1
==
last1
)
return
std
::
distance
(
first2
,
last2
);
if
(
first2
==
last2
)
return
std
::
distance
(
first1
,
last1
);
if
(
*
first1
==
*
first2
)
return
levenshtein_distance
(
std
::
next
(
first1
),
last1
,
std
::
next
(
first2
),
last2
);
auto
x1
=
levenshtein_distance
(
std
::
next
(
first1
),
last1
,
std
::
next
(
first2
),
last2
);
auto
x2
=
levenshtein_distance
(
first1
,
last1
,
std
::
next
(
first2
),
last2
);
auto
x3
=
levenshtein_distance
(
std
::
next
(
first1
),
last1
,
first2
,
last2
);
return
std
::
ptrdiff_t
{
1
}
+
std
::
min
({
x1
,
x2
,
x3
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/allocation_model.hpp
View file @
7f97b8ef
...
...
@@ -205,7 +205,7 @@ struct allocation_model
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -267,7 +267,7 @@ struct allocation_model
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/check_shapes.hpp
View file @
7f97b8ef
...
...
@@ -101,7 +101,7 @@ struct check_shapes
const
check_shapes
&
nelements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes must have only "
+
std
::
to_string
(
n
)
+
" elements"
);
return
*
this
;
}
...
...
@@ -164,7 +164,7 @@ struct check_shapes
*/
const
check_shapes
&
same_shape
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes do not match"
);
return
*
this
;
}
...
...
@@ -174,7 +174,7 @@ struct check_shapes
*/
const
check_shapes
&
same_type
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
type
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Types do not match"
);
return
*
this
;
}
...
...
@@ -184,10 +184,10 @@ struct check_shapes
*/
const
check_shapes
&
same_dims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Dimensions do not match"
);
if
(
this
->
any_of
([
&
](
const
shape
&
s
)
{
return
s
.
dynamic
();
}))
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
min_lens
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Min dynamic dimensions do not match"
);
return
*
this
;
}
...
...
@@ -197,7 +197,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
!
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
@@ -207,7 +207,7 @@ struct check_shapes
*/
const
check_shapes
&
standard
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not in standard layout"
);
return
*
this
;
}
...
...
@@ -217,7 +217,7 @@ struct check_shapes
*/
const
check_shapes
&
standard_or_scalar
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
standard
()
or
s
.
scalar
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not a scalar or in standard layout"
);
return
*
this
;
}
...
...
@@ -227,7 +227,7 @@ struct check_shapes
*/
const
check_shapes
&
packed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed"
);
return
*
this
;
}
...
...
@@ -237,7 +237,7 @@ struct check_shapes
*/
const
check_shapes
&
packed_or_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
packed
()
or
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not packed nor broadcasted"
);
return
*
this
;
}
...
...
@@ -247,7 +247,7 @@ struct check_shapes
*/
const
check_shapes
&
tuple_type
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
s
.
type
()
==
shape
::
tuple_type
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are not tuple!"
);
return
*
this
;
}
...
...
@@ -257,7 +257,7 @@ struct check_shapes
*/
const
check_shapes
&
not_transposed
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
transposed
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are transposed"
);
return
*
this
;
}
...
...
@@ -267,7 +267,7 @@ struct check_shapes
*/
const
check_shapes
&
not_broadcasted
()
const
{
if
(
!
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
if
(
not
this
->
all_of
([](
const
shape
&
s
)
{
return
not
s
.
broadcasted
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Shapes are broadcasted"
);
return
*
this
;
}
...
...
@@ -278,7 +278,7 @@ struct check_shapes
*/
const
check_shapes
&
elements
(
std
::
size_t
n
)
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
if
(
not
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
s
.
elements
()
==
n
;
}))
MIGRAPHX_THROW
(
prefix
()
+
"Wrong number of elements"
);
return
*
this
;
}
...
...
@@ -288,7 +288,8 @@ struct check_shapes
*/
const
check_shapes
&
batch_not_transposed
()
const
{
if
(
!
this
->
all_of
([
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
if
(
not
this
->
all_of
(
[
&
](
const
shape
&
s
)
{
return
batch_not_transposed_strides
(
s
.
strides
());
}))
MIGRAPHX_THROW
(
prefix
()
+
"Batch size is transposed"
);
return
*
this
;
}
...
...
src/include/migraphx/concat_opt.hpp
View file @
7f97b8ef
...
...
@@ -183,7 +183,7 @@ struct concat_optimization
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
))
{
...
...
@@ -233,7 +233,7 @@ struct concat_optimization
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/include/migraphx/context.hpp
View file @
7f97b8ef
...
...
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{
return
{};
}
template
<
class
T
>
void
wait_for_context
(
T
&
,
any_ptr
)
{
}
template
<
class
T
>
void
finish_on_context
(
T
&
,
any_ptr
)
{
}
#ifdef TYPE_ERASED_DECLARATION
...
...
@@ -78,6 +87,10 @@ struct context
void
from_value
(
const
value
&
v
);
// (optional)
any_ptr
get_queue
();
// (optional)
void
wait_for
(
any_ptr
queue
);
// (optional)
void
finish_on
(
any_ptr
queue
);
//
void
finish
()
const
;
};
...
...
@@ -165,6 +178,18 @@ struct context
return
(
*
this
).
private_detail_te_get_handle
().
get_queue
();
}
void
wait_for
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
queue
);
}
void
finish_on
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
queue
);
}
void
finish
()
const
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
...
...
@@ -187,6 +212,8 @@ struct context
virtual
value
to_value
()
const
=
0
;
virtual
void
from_value
(
const
value
&
v
)
=
0
;
virtual
any_ptr
get_queue
()
=
0
;
virtual
void
wait_for
(
any_ptr
queue
)
=
0
;
virtual
void
finish_on
(
any_ptr
queue
)
=
0
;
virtual
void
finish
()
const
=
0
;
};
...
...
@@ -231,6 +258,33 @@ struct context
return
get_queue_context
(
private_detail_te_self
);
}
template
<
class
T
>
static
auto
private_detail_te_default_wait_for
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
wait_for
(
queue
))
{
private_detail_te_self
.
wait_for
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_wait_for
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
wait_for_context
(
private_detail_te_self
,
queue
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finish_on
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
finish_on
(
queue
))
{
private_detail_te_self
.
finish_on
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_finish_on
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
finish_on_context
(
private_detail_te_self
,
queue
);
}
template
<
typename
PrivateDetailTypeErasedT
>
struct
private_detail_te_handle_type
:
private_detail_te_handle_base_type
{
...
...
@@ -246,9 +300,9 @@ struct context
template
<
typename
PrivateDetailTypeErasedU
=
PrivateDetailTypeErasedT
>
private_detail_te_handle_type
(
PrivateDetailTypeErasedT
value
,
typename
std
::
enable_if
<
!
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
typename
std
::
enable_if
<
not
std
::
is_reference
<
PrivateDetailTypeErasedU
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
private_detail_te_value
(
std
::
move
(
value
)
)
:
private_detail_te_value
(
value
)
{
}
...
...
@@ -277,6 +331,18 @@ struct context
return
private_detail_te_default_get_queue
(
char
(
0
),
private_detail_te_value
);
}
void
wait_for
(
any_ptr
queue
)
override
{
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish_on
(
any_ptr
queue
)
override
{
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
PrivateDetailTypeErasedT
private_detail_te_value
;
...
...
@@ -306,7 +372,7 @@ struct context
private_detail_te_handle_base_type
&
private_detail_te_get_handle
()
{
assert
(
private_detail_te_handle_mem_var
!=
nullptr
);
if
(
!
private_detail_te_handle_mem_var
.
unique
())
if
(
not
private_detail_te_handle_mem_var
.
unique
())
private_detail_te_handle_mem_var
=
private_detail_te_handle_mem_var
->
clone
();
return
*
private_detail_te_handle_mem_var
;
}
...
...
src/
targets/gpu/
include/migraphx/
gpu/asin
.hpp
→
src/include/migraphx/
execution_environment
.hpp
View file @
7f97b8ef
...
...
@@ -21,22 +21,21 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_ASIN
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_ASIN
_HPP
#ifndef MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#define MIGRAPHX_GUARD_
MIGRAPHLIB_EXECUTION_ENV
_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/asin.hpp>
#include <migraphx/any_ptr.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
hip_asin
:
unary_device
<
hip_asin
,
device
::
asin
>
struct
execution_environment
{
any_ptr
queue
=
any_ptr
{};
bool
async
=
false
;
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
#endif
/* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
src/include/migraphx/iterator.hpp
View file @
7f97b8ef
...
...
@@ -31,9 +31,9 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
EndIterator
>
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
!
it
.
_M_dereferenceable
())
auto
is_end
(
rank
<
2
>
,
Iterator
it
,
EndIterator
)
->
decltype
(
not
it
.
_M_dereferenceable
())
{
return
!
it
.
_M_dereferenceable
();
return
not
it
.
_M_dereferenceable
();
}
template
<
class
Iterator
,
class
EndIterator
>
...
...
src/include/migraphx/literal.hpp
View file @
7f97b8ef
...
...
@@ -45,6 +45,11 @@ struct literal : raw_data<literal>
{
literal
()
{}
/*!
* Empty literal with a specific shape type
*/
explicit
literal
(
shape
::
type_t
shape_type
)
:
m_shape
(
shape_type
,
{})
{}
template
<
class
U
,
class
T
=
deduce
<
U
>,
shape
::
type_t
ShapeType
=
shape
::
get_type
<
T
>
{}
>
literal
(
U
x
)
:
buffer
(
make_shared_array
<
char
>
(
sizeof
(
T
))),
m_shape
(
ShapeType
)
{
...
...
Prev
1
2
3
4
5
6
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment