Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
476ed17c
"backend/vscode:/vscode.git/clone" did not exist on "a01b112f7fb018131a77929bf8900c9878047c3e"
Unverified
Commit
476ed17c
authored
Aug 28, 2023
by
Brian Pickrell
Committed by
GitHub
Aug 28, 2023
Browse files
Merge branch 'develop' into rand_uniform
parents
f4f9d711
6f1c947f
Changes
96
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
413 additions
and
154 deletions
+413
-154
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+20
-1
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+1
-1
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+2
-2
src/include/migraphx/algorithm.hpp
src/include/migraphx/algorithm.hpp
+3
-2
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+1
-1
src/include/migraphx/normalize_attributes.hpp
src/include/migraphx/normalize_attributes.hpp
+32
-1
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+1
-1
src/include/migraphx/op/if_op.hpp
src/include/migraphx/op/if_op.hpp
+1
-1
src/include/migraphx/op/loop.hpp
src/include/migraphx/op/loop.hpp
+3
-3
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+216
-66
src/instruction.cpp
src/instruction.cpp
+1
-1
src/memory_coloring.cpp
src/memory_coloring.cpp
+5
-3
src/module.cpp
src/module.cpp
+4
-5
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+21
-0
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+1
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+12
-9
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+2
-3
src/onnx/parse_randomuniform_ops.cpp
src/onnx/parse_randomuniform_ops.cpp
+1
-1
src/onnx/parse_slice.cpp
src/onnx/parse_slice.cpp
+82
-49
src/program.cpp
src/program.cpp
+4
-4
No files found.
src/eliminate_contiguous.cpp
View file @
476ed17c
...
@@ -35,6 +35,8 @@
...
@@ -35,6 +35,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
)
static
bool
try_compute_shape
(
instruction_ref
ins
,
static
bool
try_compute_shape
(
instruction_ref
ins
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
std
::
vector
<
module_ref
>&
mods
)
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins,
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
return
(
arg
==
ins
)
?
new_shape
:
arg
->
get_shape
();
});
});
if
(
not
try_compute_shape
(
output
,
input_shapes
,
mods
))
if
(
not
try_compute_shape
(
output
,
input_shapes
,
output
->
module_inputs
()
))
{
{
return
false
;
return
false
;
}
}
}
}
}
}
catch
(
const
std
::
exception
&
e
)
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Exception: "
<<
e
.
what
()
<<
std
::
endl
;
}
return
false
;
}
catch
(...)
catch
(...)
{
{
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"Unknown exception"
<<
std
::
endl
;
}
return
false
;
return
false
;
}
}
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
...
@@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
{
{
if
(
arg
->
name
()
!=
op_name
)
if
(
arg
->
name
()
!=
op_name
)
continue
;
continue
;
if
(
enabled
(
MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS
{}))
{
std
::
cout
<<
"eliminate_contiguous: "
;
m
.
debug_print
(
ins
);
}
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
...
src/fuse_pointwise.cpp
View file @
476ed17c
...
@@ -41,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
...
@@ -41,7 +41,7 @@ static literal get_scalar(instruction_ref ins)
if
(
ins
->
name
()
==
"contiguous"
)
if
(
ins
->
name
()
==
"contiguous"
)
return
get_scalar
(
ins
->
inputs
().
front
());
return
get_scalar
(
ins
->
inputs
().
front
());
const
auto
&
s
=
ins
->
get_shape
();
const
auto
&
s
=
ins
->
get_shape
();
if
(
s
.
elements
()
!=
1
&&
not
(
s
.
scalar
()))
if
(
s
.
elements
()
!=
1
and
not
(
s
.
scalar
()))
return
{};
return
{};
if
(
not
ins
->
can_eval
())
if
(
not
ins
->
can_eval
())
return
{};
return
{};
...
...
src/fuse_reduce.cpp
View file @
476ed17c
...
@@ -52,7 +52,7 @@ struct fused_reduce
...
@@ -52,7 +52,7 @@ struct fused_reduce
{
{
if
(
mods
.
size
()
!=
1
)
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
MIGRAPHX_THROW
(
"should have one submodule."
);
auto
*
sm
=
mods
.
front
();
const
auto
*
sm
=
mods
.
front
();
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
MIGRAPHX_THROW
(
"Only one output supported"
);
auto
names
=
sm
->
get_parameter_names
();
auto
names
=
sm
->
get_parameter_names
();
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
...
@@ -143,7 +143,7 @@ insert_module_in_submodule(module_ref sm,
}
}
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
find_inputs
(
module_ref
sm
,
find_inputs
(
const_
module_ref
sm
,
const
module
&
parent
,
const
module
&
parent
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
map_ins
)
{
{
...
...
src/include/migraphx/algorithm.hpp
View file @
476ed17c
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -100,8 +102,7 @@ inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
...
@@ -100,8 +102,7 @@ inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
std
::
vector
<
size_t
>
d
(
l2
+
1
);
std
::
vector
<
size_t
>
d
(
l2
+
1
);
for
(
size_t
j
=
1
;
j
<=
l2
;
j
++
)
std
::
iota
(
d
.
begin
(),
d
.
end
(),
0
);
d
[
j
]
=
j
;
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
for
(
size_t
i
=
1
;
i
<=
l1
;
i
++
)
{
{
...
...
src/include/migraphx/check_shapes.hpp
View file @
476ed17c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
src/include/migraphx/normalize_attributes.hpp
View file @
476ed17c
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <cstring>
#include <cstring>
#include <vector>
#include <vector>
#include <migraphx/op/normalize_attribute.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -42,6 +43,36 @@ struct select_dependent_type
...
@@ -42,6 +43,36 @@ struct select_dependent_type
template
<
class
T
,
class
...
Ts
>
template
<
class
T
,
class
...
Ts
>
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
using
dependent_type
=
typename
select_dependent_type
<
T
,
Ts
...
>::
type
;
/**
* Used to normalize variable input axes at model runtime.
* Example: the axes inputs of the slice operator.
*
* \param axes the axes to normalize
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std
::
vector
<
int64_t
>
normalize_axes
(
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
=
""
);
/**
* Used to normalize variable input axes at model runtime.
* Example: the starts and ends inputs of the slice operator.
*
* \param indices the indices to normalize
* \param axes which axes the indices apply over
* \param input_shape shape of the input tensor
* \param attr_val the normalize_axes attributes from the operator
* \param prefix error message prefix
*/
std
::
vector
<
int64_t
>
normalize_indices
(
const
std
::
vector
<
int64_t
>&
indices
,
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
=
""
);
MIGRAPHX_EXPORT
MIGRAPHX_EXPORT
bool
normalize_attributes
(
operation
&
op
,
const
shape
&
input_shape
);
bool
normalize_attributes
(
operation
&
op
,
const
shape
&
input_shape
);
...
...
src/include/migraphx/op/convolution.hpp
View file @
476ed17c
...
@@ -82,7 +82,7 @@ struct convolution
...
@@ -82,7 +82,7 @@ struct convolution
const
auto
input_ndim
=
inputs
[
0
].
ndim
();
const
auto
input_ndim
=
inputs
[
0
].
ndim
();
const
auto
padding_size
=
padding
.
size
();
const
auto
padding_size
=
padding
.
size
();
if
(
input_ndim
!=
padding_size
/
2
+
2
&&
input_ndim
!=
padding_size
+
2
)
if
(
input_ndim
!=
padding_size
/
2
+
2
and
input_ndim
!=
padding_size
+
2
)
{
{
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
MIGRAPHX_THROW
(
"CONVOLUTION: input and attribute size mismatch!"
);
}
}
...
...
src/include/migraphx/op/if_op.hpp
View file @
476ed17c
...
@@ -71,7 +71,7 @@ struct if_op
...
@@ -71,7 +71,7 @@ struct if_op
std
::
unordered_map
<
std
::
string
,
argument
>
params
;
std
::
unordered_map
<
std
::
string
,
argument
>
params
;
std
::
set
<
std
::
string
>
pnames
;
std
::
set
<
std
::
string
>
pnames
;
for
(
const
auto
&
smod
:
mods
)
for
(
const
_module_ref
smod
:
mods
)
{
{
auto
names
=
smod
->
get_parameter_names
();
auto
names
=
smod
->
get_parameter_names
();
pnames
.
insert
(
names
.
begin
(),
names
.
end
());
pnames
.
insert
(
names
.
begin
(),
names
.
end
());
...
...
src/include/migraphx/op/loop.hpp
View file @
476ed17c
...
@@ -59,7 +59,7 @@ struct loop
...
@@ -59,7 +59,7 @@ struct loop
MIGRAPHX_THROW
(
"LOOP: operator should have one submodule."
);
MIGRAPHX_THROW
(
"LOOP: operator should have one submodule."
);
}
}
const
auto
&
mod
=
mods
.
front
();
const
_module_ref
mod
=
mods
.
front
();
auto
mod_out_shapes
=
mod
->
get_output_shapes
();
auto
mod_out_shapes
=
mod
->
get_output_shapes
();
auto
dep_param_num
=
inputs
.
size
()
-
2
;
auto
dep_param_num
=
inputs
.
size
()
-
2
;
...
...
src/include/migraphx/op/slice.hpp
View file @
476ed17c
...
@@ -27,19 +27,34 @@
...
@@ -27,19 +27,34 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Slice operator that accepts variable axes, starts and ends.
*
* Attributes:
* axes: constant axes to slice over (optional)
* starts: constant slice starting indices (optional)
* ends: constant slice ending indices (optional)
*
* Parameters:
* data: the input tensor to slice (dynamic or static shape)
* input_starts: starting indicies of slice (optional, static shape)
* input_ends: ending indicies of slice (optional, static shape)
* input_axes: axes to slice over (optional, static shape)
*/
struct
slice
struct
slice
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
{}
;
std
::
vector
<
int64_t
>
starts
;
std
::
vector
<
int64_t
>
starts
{}
;
std
::
vector
<
int64_t
>
ends
;
std
::
vector
<
int64_t
>
ends
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -48,8 +63,8 @@ struct slice
...
@@ -48,8 +63,8 @@ struct slice
}
}
/**
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are
in
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are
* limits.
*
within
limits.
*/
*/
value
attributes
()
const
value
attributes
()
const
{
{
...
@@ -70,100 +85,235 @@ struct slice
...
@@ -70,100 +85,235 @@ struct slice
std
::
string
name
()
const
{
return
"slice"
;
}
std
::
string
name
()
const
{
return
"slice"
;
}
auto
compute_offset
(
const
shape
&
s
)
const
/**
{
* Computes the slice output shape dimensions for given starts, ends,and axes.
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
* Templated to also handle tensor views.
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
* Possibily different type between [in_starts, in_ends] and [in_axes] if in_axes is this
auto
offset
=
0
;
* object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid.
if
(
not
axes
.
empty
())
*/
{
template
<
class
A
,
class
B
>
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
std
::
vector
<
std
::
size_t
>
{
lens_calc
(
const
std
::
vector
<
std
::
size_t
>&
lengths
,
A
in_starts
,
A
in_ends
,
B
in_axes
)
const
auto
axis
=
axes
[
i
];
offset
+=
starts
[
i
]
*
strides
[
axis
];
}
}
else
{
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
auto
new_lens
=
lengths
;
for
(
std
::
size_t
i
=
0
;
i
<
in_axes
.
size
();
++
i
)
{
{
offset
+=
starts
[
axis
]
*
strides
[
axis
];
auto
axis
=
in_axes
[
i
];
new_lens
[
axis
]
=
in_ends
[
i
]
-
in_starts
[
i
];
}
}
}
return
new_lens
;
return
offset
;
}
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
3
,
4
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
if
(
inputs
.
size
()
==
1
)
{
auto
t
=
input_shape
.
type
();
auto
t
=
input_shape
.
type
();
// TODO: When support for dynamic shapes is added to normalize_attributes,
// remove this restriction.
if
(
input_shape
.
dynamic
()
and
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
if
(
input_shape
.
dynamic
()
and
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
not
input_shape
.
dyn_dims
()[
axis
].
is_fixed
();
return
not
input_shape
.
dyn_dims
()[
axis
].
is_fixed
();
}))
}))
{
{
MIGRAPHX_THROW
(
"SLICE: slicing is not allowed on non-fixed dynamic input axis "
);
MIGRAPHX_THROW
(
"SLICE: slicing is not allowed on non-fixed dynamic input axis "
);
}
}
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and optimals if possible.
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_strides
;
// Doesn't handle optimals
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
old_lens
=
input_shape
.
max_lens
();
return
shape
{
t
,
new_mins
=
input_shape
.
min_lens
();
lens_calc
(
input_shape
.
min_lens
(),
starts
,
ends
,
axes
),
lens_calc
(
input_shape
.
max_lens
(),
starts
,
ends
,
axes
),
{}};
}
else
{
return
shape
{
t
,
lens_calc
(
input_shape
.
lens
(),
starts
,
ends
,
axes
),
input_shape
.
strides
()};
}
}
else
{
// check that starts, ends, and optionally input_axes are all 1D, have the same
// dimension, and are static
check_shapes
{
inputs
.
begin
()
+
1
,
inputs
.
end
(),
std
::
string
(
"SLICE: inputs (starts, ends, and input_axes)"
),
false
}
.
only_dims
(
1
)
.
same_dims
();
auto
dds
=
input_shape
.
to_dynamic
().
dyn_dims
();
if
(
inputs
.
size
()
==
3
)
{
if
(
inputs
[
1
].
lens
().
at
(
0
)
!=
axes
.
size
())
{
MIGRAPHX_THROW
(
"SLICE: inputs starts and ends do not have the same dimension "
"as the axes attribute"
);
}
std
::
for_each
(
axes
.
cbegin
(),
axes
.
cend
(),
[
&
](
const
auto
&
axis
)
{
dds
.
at
(
axis
)
=
{
0
,
dds
.
at
(
axis
).
max
};
});
}
}
else
else
{
{
old_lens
=
input_shape
.
lens
();
// if axes is an input, then all the output dimensions could be 0 to the max value
// For static shape (including during eval step after a dynamic input) the strides are
std
::
transform
(
dds
.
begin
(),
dds
.
end
(),
dds
.
begin
(),
[](
auto
dd
)
{
// indexed into the pre-slice array, so they are larger than the apparent size of the
return
shape
::
dynamic_dimension
{
0
,
dd
.
max
};
// resulting shape.
});
old_strides
=
input_shape
.
strides
();
}
return
shape
{
input_shape
.
type
(),
dds
};
}
}
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
/**
* Calculates the starting offset for the sliced tensor.
* Used in compute when only data input and all other information are in the attributes.
*
* \param s static input shape
*/
auto
compute_offset
(
const
shape
&
s
)
const
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
const
std
::
vector
<
std
::
size_t
>&
strides
=
s
.
strides
();
auto
offset
=
0
;
if
(
not
axes
.
empty
())
{
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
{
auto
axis
=
axes
[
i
];
auto
axis
=
axes
[
i
];
size_t
sliced_length
=
ends
[
i
]
-
starts
[
i
];
offset
+=
starts
[
i
]
*
strides
[
axis
];
// A Numpy indexing convention: a slice size larger than the actual dimension
}
// is legal and the "ends" value is clipped to the axis size
}
new_lens
[
axis
]
=
std
::
min
(
new_lens
[
axis
],
sliced_length
);
else
if
(
input_shape
.
dynamic
())
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
offset
+=
starts
[
axis
]
*
strides
[
axis
];
// sliced_length, making use of TBD start/end values.
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
// if the slice size is smaller than maxes but larger than mins
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
}
}
}
}
if
(
input_shape
.
dynamic
())
return
offset
*
s
.
type_size
();
}
/**
* Calculates the starting offset for the sliced tensor (for aliasing).
* Used when the starts and/or the axes are inputs.
*
* \param s static input shape
* \param input_starts starting indices of slice
* \param ax_vec axes to slice on
*/
template
<
class
IndView
,
class
Axes
>
auto
compute_offset
(
const
shape
&
s
,
const
IndView
&
input_starts
,
const
Axes
&
ax_vec
)
const
{
{
return
shape
{
t
,
new_mins
,
new_lens
,
{}};
auto
ret
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
ax_vec
.
size
();
++
i
)
{
auto
axis
=
ax_vec
[
i
];
ret
+=
input_starts
[
i
]
*
s
.
strides
().
at
(
axis
);
}
}
else
return
ret
*
s
.
type_size
();
}
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
normalize_inputs
(
const
shape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
input_starts
,
const
std
::
vector
<
int64_t
>&
input_ends
)
const
{
{
return
shape
{
t
,
new_lens
,
old_strides
};
auto
attrs
=
this
->
attributes
().
at
(
"normalize_axes"
);
return
{{
"input_starts"
,
normalize_indices
(
input_starts
,
this
->
axes
,
input_shape
,
attrs
.
at
(
"starts"
),
"Slice variable input_starts"
)},
{
"input_ends"
,
normalize_indices
(
input_ends
,
this
->
axes
,
input_shape
,
attrs
.
at
(
"ends"
),
"Slice variable input_ends"
)}};
}
}
/**
* Three input version of the normalize_inputs.
* This one also checks that the input_axes are valid.
*/
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int64_t
>>
normalize_inputs
(
shape
input_shape
,
const
std
::
vector
<
int64_t
>&
input_starts
,
const
std
::
vector
<
int64_t
>&
input_ends
,
const
std
::
vector
<
int64_t
>&
input_axes
)
const
{
auto
attrs
=
this
->
attributes
().
at
(
"normalize_axes"
);
auto
norm_axes
=
normalize_axes
(
input_axes
,
input_shape
,
attrs
.
at
(
"axes"
),
"Slice variable input_axes"
);
return
{{
"input_starts"
,
normalize_indices
(
input_starts
,
norm_axes
,
input_shape
,
attrs
.
at
(
"starts"
),
"Slice variable input_starts"
)},
{
"input_ends"
,
normalize_indices
(
input_ends
,
norm_axes
,
input_shape
,
attrs
.
at
(
"ends"
),
"Slice variable input ends"
)},
{
"input_axes"
,
norm_axes
}};
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
input
=
args
[
0
];
auto
input
=
args
[
0
];
auto
input_shape
=
input
.
get_shape
();
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
dyn_out
.
computed_shape
.
type_size
();
switch
(
args
.
size
())
{
case
1
:
{
std
::
size_t
offset
=
compute_offset
(
input_shape
);
return
{
dyn_out
.
computed_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
return
{
dyn_out
.
computed_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
}
case
3
:
{
shape
calc_shape
;
std
::
size_t
offset
=
0
;
visit_all
(
args
[
1
],
args
[
2
])([
&
](
auto
input_starts
,
auto
input_ends
)
{
auto
norm_inputs
=
normalize_inputs
(
input_shape
,
input_starts
.
template
to_vector
<
int64_t
>(),
input_ends
.
template
to_vector
<
int64_t
>());
offset
=
compute_offset
(
input_shape
,
norm_inputs
.
at
(
"input_starts"
),
this
->
axes
);
calc_shape
=
{
input_shape
.
type
(),
lens_calc
(
input_shape
.
lens
(),
norm_inputs
.
at
(
"input_starts"
),
norm_inputs
.
at
(
"input_ends"
),
this
->
axes
),
input_shape
.
strides
()};
});
return
{
calc_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
case
4
:
{
shape
calc_shape
;
std
::
size_t
offset
=
0
;
visit_all
(
args
[
1
],
args
[
2
],
args
[
3
])(
[
&
](
auto
input_starts
,
auto
input_ends
,
auto
input_axes
)
{
auto
norm_inputs
=
normalize_inputs
(
input_shape
,
input_starts
.
template
to_vector
<
int64_t
>(),
input_ends
.
template
to_vector
<
int64_t
>(),
input_axes
.
template
to_vector
<
int64_t
>());
offset
=
compute_offset
(
input_shape
,
norm_inputs
.
at
(
"input_starts"
),
norm_inputs
.
at
(
"input_axes"
));
calc_shape
=
shape
{
input_shape
.
type
(),
lens_calc
(
input_shape
.
lens
(),
norm_inputs
.
at
(
"input_starts"
),
norm_inputs
.
at
(
"input_ends"
),
norm_inputs
.
at
(
"input_axes"
)),
input_shape
.
strides
()};
});
return
{
calc_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
default:
{
// Should never get here; covering in case some code change occurs
MIGRAPHX_THROW
(
"SLICE: invalid number of inputs"
);
}
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/instruction.cpp
View file @
476ed17c
...
@@ -389,7 +389,7 @@ void instruction::print(std::ostream& os,
...
@@ -389,7 +389,7 @@ void instruction::print(std::ostream& os,
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
{
{
std
::
string
delim
=
", ["
;
std
::
string
delim
=
", ["
;
for
(
auto
&
&
mod_arg
:
ins
->
module_inputs
())
for
(
const
const_module_ref
&
mod_arg
:
ins
->
module_inputs
())
{
{
os
<<
delim
<<
mod_arg
->
name
();
os
<<
delim
<<
mod_arg
->
name
();
delim
=
", "
;
delim
=
", "
;
...
...
src/memory_coloring.cpp
View file @
476ed17c
...
@@ -23,9 +23,9 @@
...
@@ -23,9 +23,9 @@
*/
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
...
@@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
s
)},
{
"offset"
,
offset
}}),
mem
);
}
}
// Replace zero allocation
// Replace zero allocation
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
...
@@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const
if
(
ins
->
name
()
!=
allocation_op
)
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
m
.
replace_instruction
(
ins
,
make_op
(
"load"
,
{{
"shape"
,
to_value
(
ins
->
get_shape
())},
{
"offset"
,
0
}}),
mem
);
}
}
// Remove scratch parameter if its not used
// Remove scratch parameter if its not used
...
...
src/module.cpp
View file @
476ed17c
...
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
...
@@ -873,12 +873,11 @@ module::print_py(std::ostream& os,
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
)
{
{
os
<<
mname
<<
".add_literal("
;
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
const
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
// Disable abs for now
use_abs
=
false
;
// ins->get_literal().visit([&](auto v) {
// use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
// });
if
(
use_abs
)
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_argument("
;
os
<<
"migraphx.generate_argument("
;
...
...
src/normalize_attributes.cpp
View file @
476ed17c
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
...
@@ -49,6 +49,10 @@ auto tune_attribute(const std::vector<int64_t>& vec,
Message
m
)
Message
m
)
{
{
std
::
vector
<
int64_t
>
result
(
vec
);
std
::
vector
<
int64_t
>
result
(
vec
);
if
(
result
.
empty
())
{
return
result
;
};
int64_t
n_rank
=
input_shape
.
ndim
();
int64_t
n_rank
=
input_shape
.
ndim
();
std
::
vector
<
op
::
normalize_attribute
>
vec_attrs
=
val
.
to_vector
<
op
::
normalize_attribute
>
();
std
::
vector
<
op
::
normalize_attribute
>
vec_attrs
=
val
.
to_vector
<
op
::
normalize_attribute
>
();
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
use_output
))
if
(
contains
(
vec_attrs
,
op
::
normalize_attribute
::
use_output
))
...
@@ -251,5 +255,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
...
@@ -251,5 +255,22 @@ bool normalize_attributes(operation& op, const shape& input_shape)
return
tuned
;
return
tuned
;
}
}
std
::
vector
<
int64_t
>
normalize_axes
(
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
axes
,
{},
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
std
::
vector
<
int64_t
>
normalize_indices
(
const
std
::
vector
<
int64_t
>&
indices
,
const
std
::
vector
<
int64_t
>&
axes
,
const
shape
&
input_shape
,
const
value
&
attr_val
,
const
std
::
string
&
prefix
)
{
return
tune_attribute
(
indices
,
axes
,
attr_val
,
input_shape
,
[
&
]
{
return
prefix
;
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
476ed17c
...
@@ -117,6 +117,7 @@ struct onnx_parser
...
@@ -117,6 +117,7 @@ struct onnx_parser
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
parse_graph
(
module
*
mod
,
const
onnx
::
GraphProto
&
graph
,
bool
inlining
=
false
);
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
shape
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
;
};
};
...
...
src/onnx/onnx_parser.cpp
View file @
476ed17c
...
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
...
@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
}
}
shape
s
;
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
if
(
parser
.
map_input_dims
.
count
(
name
)
>
0
)
{
{
dims
=
parser
.
map_input_dims
.
at
(
name
);
std
::
vector
<
std
::
size_t
>
dims
=
parser
.
map_input_dims
.
at
(
name
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
s
=
parser
.
parse_type
(
input
.
type
(),
dims
);
}
}
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
else
if
(
parser
.
map_dyn_input_dims
.
count
(
name
)
>
0
)
...
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
...
@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
}
}
else
else
{
{
s
=
parser
.
parse_type
(
input
.
type
()
,
dims
);
s
=
parser
.
parse_type
(
input
.
type
());
}
}
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
...
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
}
}
MIGRAPHX_THROW
(
"PARSE_TENSOR: Invalid tensor type"
);
MIGRAPHX_THROW
(
"PARSE_TENSOR: Invalid tensor type"
);
}
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
)
const
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
not
input_dims
.
empty
())
{
return
{
shape_type
,
input_dims
};
}
std
::
vector
<
shape
::
dynamic_dimension
>
dynamic_dims
;
std
::
vector
<
shape
::
dynamic_dimension
>
dynamic_dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
return
shape_from_dyn_dims
(
shape_type
,
dynamic_dims
);
}
}
shape
onnx_parser
::
parse_type
(
const
onnx
::
TypeProto
&
t
,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
input_dims
.
empty
())
return
{
shape_type
};
return
{
shape_type
,
input_dims
};
}
shape
::
type_t
get_type
(
int
dtype
)
shape
::
type_t
get_type
(
int
dtype
)
{
{
switch
(
dtype
)
switch
(
dtype
)
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
476ed17c
...
@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -55,9 +55,6 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val
=
literal
({
shape
::
float_type
,
{
1
},
{
0
}},
{
0.0
f
});
l_val
=
literal
({
shape
::
float_type
,
{
1
},
{
0
}},
{
0.0
f
});
}
}
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
if
(
args
.
empty
())
if
(
args
.
empty
())
{
{
MIGRAPHX_THROW
(
"ConstantOfShape : must have 1 input!"
);
MIGRAPHX_THROW
(
"ConstantOfShape : must have 1 input!"
);
...
@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -65,6 +62,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
else
else
{
{
migraphx
::
shape
s
;
migraphx
::
shape
s
;
// input is empty, output is a scalar
auto
type
=
l_val
.
get_shape
().
type
();
// empty input tensor, output is a scalar
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
{
...
...
src/onnx/parse_randomuniform_ops.cpp
View file @
476ed17c
...
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
...
@@ -96,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if
(
contains
(
info
.
attributes
,
"seed"
))
if
(
contains
(
info
.
attributes
,
"seed"
))
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
gen
.
seed
(
info
.
attributes
.
at
(
"seed"
).
f
());
std
::
uniform_real_distribution
<>
d
(
high
,
low
);
std
::
uniform_real_distribution
<>
d
(
low
,
high
);
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
vector
<
double
>
rand_vals
(
out_shape
.
elements
());
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
std
::
generate
(
rand_vals
.
begin
(),
rand_vals
.
end
(),
[
&
]()
{
return
d
(
gen
);
});
...
...
src/onnx/parse_slice.cpp
View file @
476ed17c
...
@@ -34,16 +34,65 @@ namespace onnx {
...
@@ -34,16 +34,65 @@ namespace onnx {
struct
parse_slice
:
op_parser
<
parse_slice
>
struct
parse_slice
:
op_parser
<
parse_slice
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Slice"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Slice"
}};
}
struct
slice_desc
{
op
::
slice
op
;
std
::
vector
<
instruction_ref
>
op_args
;
std
::
vector
<
int64_t
>
steps
;
std
::
vector
<
int64_t
>
raxes
;
void
always_insert
(
instruction_ref
arg
)
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
std
::
vector
<
int64_t
>
insert
(
instruction_ref
arg
)
{
std
::
vector
<
int64_t
>
result
;
migraphx
::
argument
arg_value
=
arg
->
eval
();
if
(
arg_value
.
empty
())
{
op_args
.
insert
(
op_args
.
begin
(),
arg
);
}
else
{
arg_value
.
visit
([
&
](
auto
s
)
{
result
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
return
result
;
}
};
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
parser
,
const
onnx_parser
&
parser
,
const
onnx_parser
::
node_info
&
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
auto
sd
=
construct_slice_desc
(
parser
,
info
,
args
);
auto
ins
=
info
.
add_instruction
(
sd
.
op
,
sd
.
op_args
);
if
(
not
sd
.
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
sd
.
raxes
}}),
ins
);
}
// If any steps are other than default 1, add a "steps" op
if
(
std
::
any_of
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
sd
.
steps
.
begin
(),
sd
.
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
sd
.
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
}
slice_desc
construct_slice_desc
(
const
onnx_parser
&
parser
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
op
::
slice
op
;
slice_desc
sd
;
std
::
vector
<
int64_t
>
steps
;
// slice can have up to 5 inputs, we first check the 5th one
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice.
// to decide whether MIGRAPHX can handle this slice.
...
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
...
@@ -51,89 +100,73 @@ struct parse_slice : op_parser<parse_slice>
{
{
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
migraphx
::
argument
step_arg
=
args
.
back
()
->
eval
();
check_arg_empty
(
step_arg
,
"PARSE_SLICE: cannot handle variable steps for slice"
);
check_arg_empty
(
step_arg
,
"PARSE_SLICE: cannot handle variable steps for slice"
);
step_arg
.
visit
([
&
](
auto
s
)
{
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
step_arg
.
visit
([
&
](
auto
s
)
{
sd
.
steps
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
if
(
args
.
size
()
>=
4
)
if
(
args
.
size
()
>=
4
)
{
{
migraphx
::
argument
axes_arg
=
args
.
at
(
3
)
->
eval
();
sd
.
op
.
axes
=
sd
.
insert
(
args
.
at
(
3
));
check_arg_empty
(
axes_arg
,
"PARSE_SLICE: cannot handle variable axes for slice"
);
axes_arg
.
visit
([
&
](
auto
s
)
{
op
.
axes
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"axes"
))
else
if
(
contains
(
info
.
attributes
,
"axes"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axes"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axes"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
axes
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
axes
));
});
}
}
if
(
args
.
size
()
>=
3
)
if
(
args
.
size
()
>=
3
)
{
{
migraphx
::
argument
end_arg
=
args
.
at
(
2
)
->
eval
();
sd
.
op
.
ends
=
sd
.
insert
(
args
.
at
(
2
));
check_arg_empty
(
end_arg
,
"PARSE_SLICE: cannot handle variable ends for slice"
);
end_arg
.
visit
([
&
](
auto
s
)
{
op
.
ends
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"ends"
))
else
if
(
contains
(
info
.
attributes
,
"ends"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"ends"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"ends"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
ends
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
ends
));
});
}
}
if
(
args
.
size
()
>=
2
)
if
(
args
.
size
()
>=
2
)
{
{
migraphx
::
argument
start_arg
=
args
.
at
(
1
)
->
eval
();
sd
.
op
.
starts
=
sd
.
insert
(
args
.
at
(
1
));
check_arg_empty
(
start_arg
,
"PARSE_SLICE: cannot handle variable starts for slice"
);
start_arg
.
visit
([
&
](
auto
s
)
{
op
.
starts
.
assign
(
s
.
begin
(),
s
.
end
());
});
}
}
else
if
(
contains
(
info
.
attributes
,
"starts"
))
else
if
(
contains
(
info
.
attributes
,
"starts"
))
{
{
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"starts"
));
literal
s
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"starts"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
starts
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
sd
.
op
.
starts
));
});
}
}
// data input argument
sd
.
always_insert
(
args
.
at
(
0
));
// If axes arg is not given, the default is all of them.
// If axes arg is not given, the default is all of them.
if
(
op
.
axes
.
empty
())
if
(
sd
.
op
.
axes
.
empty
()
and
sd
.
op_args
.
size
()
<
3
)
{
{
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
vector
<
int64_t
>
axes
(
args
[
0
]
->
get_shape
().
ndim
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
int64_t
{
0
});
op
.
axes
=
axes
;
sd
.
op
.
axes
=
axes
;
}
}
std
::
vector
<
int64_t
>
raxes
;
if
(
not
sd
.
steps
.
empty
())
{
if
(
sd
.
op
.
starts
.
empty
()
or
sd
.
op
.
ends
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable starts and ends is not supported"
);
if
(
sd
.
op
.
axes
.
empty
())
MIGRAPHX_THROW
(
"PARSE_SLICE: steps and variable axes is not supported"
);
}
assert
(
steps
.
empty
()
or
steps
.
size
()
==
op
.
axes
.
size
());
assert
(
sd
.
steps
.
empty
()
or
sd
.
steps
.
size
()
==
sd
.
op
.
axes
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
starts
.
size
());
assert
(
op
.
axes
.
size
()
==
op
.
ends
.
size
());
// If any axes have negative step, prepare to add a "reverse" op
// If any axes have negative step, prepare to add a "reverse" op
for
(
auto
i
:
range
(
steps
.
size
()))
for
(
auto
i
:
range
(
sd
.
steps
.
size
()))
{
{
if
(
steps
[
i
]
>=
0
)
if
(
sd
.
steps
[
i
]
>=
0
)
continue
;
continue
;
op
.
starts
[
i
]
+=
1
;
sd
.
op
.
starts
[
i
]
+=
1
;
if
(
op
.
starts
[
i
]
==
0
)
if
(
sd
.
op
.
starts
[
i
]
==
0
)
op
.
starts
[
i
]
=
INT_MAX
;
sd
.
op
.
starts
[
i
]
=
INT_MAX
;
op
.
ends
[
i
]
+=
1
;
sd
.
op
.
ends
[
i
]
+=
1
;
raxes
.
push_back
(
op
.
axes
[
i
]);
sd
.
raxes
.
push_back
(
sd
.
op
.
axes
[
i
]);
std
::
swap
(
op
.
starts
[
i
],
op
.
ends
[
i
]);
std
::
swap
(
sd
.
op
.
starts
[
i
],
sd
.
op
.
ends
[
i
]);
}
auto
ins
=
info
.
add_instruction
(
op
,
args
[
0
]);
if
(
not
raxes
.
empty
())
{
ins
=
info
.
add_instruction
(
make_op
(
"reverse"
,
{{
"axes"
,
raxes
}}),
ins
);
}
}
// If any steps are other than default 1, add a "steps" op
return
sd
;
if
(
std
::
any_of
(
steps
.
begin
(),
steps
.
end
(),
[](
auto
s
)
{
return
std
::
abs
(
s
)
!=
1
;
}))
{
std
::
vector
<
int64_t
>
nsteps
;
std
::
transform
(
steps
.
begin
(),
steps
.
end
(),
std
::
back_inserter
(
nsteps
),
[](
auto
s
)
{
return
std
::
abs
(
s
);
});
return
ins
=
info
.
add_instruction
(
make_op
(
"step"
,
{{
"axes"
,
op
.
axes
},
{
"steps"
,
nsteps
}}),
ins
);
}
else
return
ins
;
}
}
};
};
...
...
src/program.cpp
View file @
476ed17c
...
@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
...
@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots
// Gather all the target roots
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
std
::
unordered_multimap
<
std
::
size_t
,
module_ref
>
roots
;
auto
mods
=
this
->
get_modules
();
auto
mods
=
this
->
get_modules
();
for
(
auto
*
mod
:
mods
)
for
(
const
auto
*
mod
:
mods
)
{
{
for
(
const
auto
&
ins
:
*
mod
)
for
(
const
auto
&
ins
:
*
mod
)
{
{
...
@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
...
@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out
[
x
]
=
ss
.
str
();
ins_out
[
x
]
=
ss
.
str
();
});
});
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
ret
=
generic_eval
(
*
this
,
contexts
,
std
::
move
(
params
),
[
&
](
instruction_ref
ins
,
auto
f
)
{
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
const
auto
&
ctx
=
contexts
[
ins
->
get_target_id
()];
ctx
.
finish
();
ctx
.
finish
();
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
std
::
cout
<<
"Run instruction: "
<<
ins_out
.
at
(
ins
)
<<
std
::
endl
;
timer
t
{};
timer
t
{};
...
@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
...
@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
std
::
back_inserter
(
module_inputs
),
std
::
back_inserter
(
module_inputs
),
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
[
&
](
const
value
&
i
)
{
return
map_mods
.
at
(
i
.
to
<
std
::
string
>
());
});
for
(
auto
&
smod
:
module_inputs
)
for
(
const
auto
&
smod
:
module_inputs
)
{
{
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
mod_from_val
(
smod
,
v
,
instructions
,
map_mods
);
}
}
...
@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
...
@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
std
::
vector
<
module
*>
unused
;
std
::
vector
<
module
*>
unused
;
generic_get_unused_modules
(
generic_get_unused_modules
(
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
impl
->
modules
,
generic_get_modules
(
this
->
get_main_module
()),
std
::
back_inserter
(
unused
));
for
(
auto
*
m
:
unused
)
for
(
const
auto
*
m
:
unused
)
this
->
remove_module
(
m
->
name
());
this
->
remove_module
(
m
->
name
());
}
}
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment