Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9bff4331
Commit
9bff4331
authored
Mar 21, 2023
by
Paul
Browse files
Merge
parents
214b313f
94a7f6ee
Changes
274
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1074 additions
and
148 deletions
+1074
-148
src/include/migraphx/op/reverse.hpp
src/include/migraphx/op/reverse.hpp
+2
-0
src/include/migraphx/op/scatternd_op.hpp
src/include/migraphx/op/scatternd_op.hpp
+66
-21
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+146
-0
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+69
-28
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+5
-5
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+63
-29
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+76
-41
src/include/migraphx/op/where.hpp
src/include/migraphx/op/where.hpp
+13
-5
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+2
-0
src/include/migraphx/optimize_module.hpp
src/include/migraphx/optimize_module.hpp
+16
-6
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-0
src/include/migraphx/register_op.hpp
src/include/migraphx/register_op.hpp
+22
-1
src/include/migraphx/register_target.hpp
src/include/migraphx/register_target.hpp
+15
-1
src/include/migraphx/replace_allocate.hpp
src/include/migraphx/replace_allocate.hpp
+3
-0
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+29
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+16
-0
src/instruction.cpp
src/instruction.cpp
+18
-0
src/memory_coloring.cpp
src/memory_coloring.cpp
+405
-0
src/module.cpp
src/module.cpp
+90
-0
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+17
-10
No files found.
src/include/migraphx/op/reverse.hpp
View file @
9bff4331
...
...
@@ -28,6 +28,7 @@
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
...
@@ -60,6 +61,7 @@ struct reverse
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
return
inputs
[
0
].
with_lens
(
inputs
[
0
].
lens
());
}
...
...
src/include/migraphx/op/scatternd_op.hpp
View file @
9bff4331
...
...
@@ -28,44 +28,89 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* @brief
* N-dimensional Scatter operations. This struct is parent class to ops which differ in what formula
* is used to reduce (combine old and new values of) the scattered value. It was originally based
* on Onnx ScatterND operation (see
* https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterND) and is also similar to Numpy
* numpy.add.at().
*
* @tparam Derived a template parameter in the CRTP inheritance idiom, represents one of the child
* operations.
*/
template
<
class
Derived
>
struct
scatternd_op
:
op_name
<
Derived
>
{
/** Validate input shapes and return the correct output shape. For Scatter ops, the output
* is the same shape as the data tensor (first input), but cast to a standard shape.
*
*/
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
auto
r
=
inputs
.
front
().
lens
().
size
();
auto
q
=
inputs
.
at
(
1
).
lens
().
size
();
auto
k
=
inputs
.
at
(
1
).
lens
().
back
();
auto
ind_lens
=
inputs
.
at
(
1
).
lens
();
auto
upd_lens
=
inputs
.
back
().
lens
();
auto
data_lens
=
inputs
.
front
().
lens
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
);
auto
data_shape
=
inputs
.
front
();
auto
index_shape
=
inputs
.
at
(
1
);
auto
upd_shape
=
inputs
.
back
();
auto
r
=
data_shape
.
ndim
();
auto
q
=
index_shape
.
ndim
();
size_t
k
;
if
(
index_shape
.
dynamic
())
{
// the rank of the output is a function of k, so k must be fixed.
if
(
not
index_shape
.
dyn_dims
().
back
().
is_fixed
())
{
MIGRAPHX_THROW
(
"GATHERND: last dimension of indices tensor must be fixed (min=max)"
);
}
k
=
index_shape
.
dyn_dims
().
back
().
min
;
}
else
k
=
index_shape
.
lens
().
back
();
// Checks on the sizes of input tensors
if
(
q
+
r
!=
upd_shape
.
ndim
()
+
k
+
1
)
MIGRAPHX_THROW
(
"ScatterND: ranks of inputs don't match. "
+
std
::
to_string
(
q
)
+
" + "
+
std
::
to_string
(
r
)
+
" - "
+
std
::
to_string
(
k
)
+
" - 1 != "
+
std
::
to_string
(
upd_shape
.
ndim
()));
if
(
k
>
r
)
MIGRAPHX_THROW
(
"ScatterND: index of size "
+
std
::
to_string
(
k
)
+
" is too large for tensor of rank "
+
std
::
to_string
(
r
));
if
(
not
(
std
::
equal
(
ind_lens
.
begin
(),
ind_lens
.
begin
()
+
q
-
1
,
upd_lens
.
begin
())
and
std
::
equal
(
data_lens
.
begin
()
+
k
,
data_lens
.
end
(),
upd_lens
.
begin
()
+
q
-
1
)))
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]"
);
auto
s
=
inputs
.
front
();
if
(
s
.
broadcasted
())
// Convert all static shape dimensions to dynamic so they can be compared.
// It's possible for some of the 3 inputs to be dynamic shapes and some static,
// but any dynamic dimension that's compared to a static dimension must be fixed.
auto
ind_dims
=
index_shape
.
to_dynamic
().
dyn_dims
();
auto
upd_dims
=
upd_shape
.
to_dynamic
().
dyn_dims
();
auto
data_dims
=
data_shape
.
to_dynamic
().
dyn_dims
();
// Check that corresponding portions of tensor shapes match.
if
(
not
(
std
::
equal
(
ind_dims
.
begin
(),
ind_dims
.
begin
()
+
q
-
1
,
upd_dims
.
begin
())
and
std
::
equal
(
data_dims
.
begin
()
+
k
,
data_dims
.
end
(),
upd_dims
.
begin
()
+
q
-
1
)))
MIGRAPHX_THROW
(
"ScatterND: incorrect update shape. Update dimensions must match "
"indices and data."
);
if
(
data_shape
.
dynamic
())
return
data_shape
;
else
if
(
data_shape
.
broadcasted
())
{
return
{
s
.
type
(),
s
.
lens
()};
return
{
data_shape
.
type
(),
data_shape
.
lens
()};
}
else
{
return
s
.
with_lens
(
s
.
lens
());
return
data_shape
.
with_lens
(
data_shape
.
lens
());
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
&
self
=
static_cast
<
const
Derived
&>
(
*
this
);
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
updates
)
{
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
...
...
@@ -74,8 +119,8 @@ struct scatternd_op : op_name<Derived>
auto
updates_std
=
shape
{
updates_shape
.
type
(),
updates_shape
.
lens
()};
auto
indices_shape
=
indices
.
get_shape
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
q
=
indices_shape
.
lens
().
size
();
auto
r
=
out
put_shape
.
lens
().
size
();
auto
q
=
indices_shape
.
ndim
();
auto
r
=
dyn_out
.
com
put
ed
_shape
.
ndim
();
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
auto
updates_idx
=
updates_std
.
multi
(
i
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
...
...
@@ -89,7 +134,7 @@ struct scatternd_op : op_name<Derived>
std
::
copy
(
index_start
,
index_end
,
out_idx
.
begin
());
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
self
.
reduction
()(
output
[
out
put_shape
.
index
(
out_idx
)],
updates
[
i
]);
self
.
reduction
()(
output
[
dyn_out
.
com
put
ed
_shape
.
index
(
out_idx
)],
updates
[
i
]);
});
});
});
...
...
src/include/migraphx/op/select_module.hpp
0 → 100644
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SELECT_MODULE_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
select_module
{
shape
output_dyn_shapes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
output_dyn_shapes
,
"output_dyn_shapes"
));
}
std
::
string
name
()
const
{
return
"select_module"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
module_ref
>&
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has_at_least
(
1
);
return
shape
{
output_dyn_shapes
};
}
std
::
vector
<
std
::
string
>
get_input_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
not
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
std
::
vector
<
std
::
string
>
get_output_parameter_names
(
module_ref
mod
)
const
{
auto
param_names
=
mod
->
get_parameter_names
();
std
::
vector
<
std
::
string
>
ret
;
std
::
copy_if
(
param_names
.
cbegin
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
contains
(
pn
,
"#output_"
);
});
return
ret
;
}
argument
compute
(
const
shape
&
,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
module_ref
>&
submodule_list
,
const
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>&
run
)
const
{
// Find submodule with input parameter shapes exactly the same as the input instruction
// arguments. Assuming instruction arguments are in the same order as the instruction
// parameters.
auto
module_iter
=
std
::
find_if
(
submodule_list
.
cbegin
(),
submodule_list
.
cend
(),
[
&
](
module_ref
mr
)
{
auto
in_param_names
=
get_input_parameter_names
(
mr
);
auto
param_shapes
=
mr
->
get_parameter_shapes
();
assert
(
in_param_names
.
size
()
<=
args
.
size
());
return
std
::
equal
(
in_param_names
.
cbegin
(),
in_param_names
.
cend
(),
args
.
cbegin
(),
[
&
](
auto
p_name
,
auto
a
)
{
return
a
.
get_shape
()
==
param_shapes
[
p_name
];
});
});
if
(
module_iter
==
submodule_list
.
end
())
{
MIGRAPHX_THROW
(
"SELECT_MODULE: no compatible submodules found for given input shapes"
);
}
auto
*
module_to_run
=
*
module_iter
;
std
::
unordered_map
<
std
::
string
,
argument
>
p_map
;
// add input parameters to parameter_map
auto
in_param_names
=
get_input_parameter_names
(
module_to_run
);
assert
(
in_param_names
.
size
()
<=
args
.
size
());
std
::
transform
(
in_param_names
.
begin
(),
in_param_names
.
end
(),
args
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
return
std
::
make_pair
(
name
,
a
);
});
// One tuple output parameter in main module to multiple output parameters in submodule
auto
out_param_names
=
get_output_parameter_names
(
module_to_run
);
auto
output_sub_objects
=
args
.
back
().
get_sub_objects
();
assert
(
out_param_names
.
size
()
==
output_sub_objects
.
size
());
std
::
transform
(
out_param_names
.
begin
(),
out_param_names
.
end
(),
output_sub_objects
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
auto
ps
=
module_to_run
->
get_parameter_shape
(
name
);
if
(
a
.
get_shape
()
!=
ps
)
{
assert
(
ps
.
bytes
()
==
a
.
get_shape
().
bytes
());
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
}
else
{
return
std
::
make_pair
(
name
,
a
);
}
});
auto
results
=
run
(
module_to_run
,
p_map
);
return
argument
{
results
};
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/slice.hpp
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
...
...
@@ -46,6 +47,10 @@ struct slice
return
pack
(
f
(
self
.
axes
,
"axes"
),
f
(
self
.
starts
,
"starts"
),
f
(
self
.
ends
,
"ends"
));
}
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value
attributes
()
const
{
value
normalize
=
value
::
object
{};
...
...
@@ -65,14 +70,6 @@ struct slice
std
::
string
name
()
const
{
return
"slice"
;
}
auto
fix_index
(
const
std
::
vector
<
std
::
size_t
>&
lens
,
std
::
size_t
axis
,
int64_t
index
)
const
{
int64_t
r
=
std
::
min
(
index
,
static_cast
<
int64_t
>
(
lens
[
axis
]));
if
(
r
<
0
)
r
+=
lens
[
axis
];
return
std
::
size_t
(
r
);
}
auto
compute_offset
(
const
shape
&
s
)
const
{
const
std
::
vector
<
std
::
size_t
>&
lens
=
s
.
lens
();
...
...
@@ -83,14 +80,14 @@ struct slice
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
offset
+=
fix_index
(
lens
,
axis
,
starts
[
i
]
)
*
strides
[
axis
];
offset
+=
starts
[
i
]
*
strides
[
axis
];
}
}
else
{
for
(
std
::
size_t
axis
=
0
;
axis
<
lens
.
size
();
axis
++
)
{
offset
+=
fix_index
(
lens
,
axis
,
starts
[
axis
]
)
*
strides
[
axis
];
offset
+=
starts
[
axis
]
*
strides
[
axis
];
}
}
return
offset
;
...
...
@@ -98,37 +95,81 @@ struct slice
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
const
auto
&
old_lens
=
input_shape
.
lens
();
const
auto
&
old_strides
=
input_shape
.
strides
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
t
=
input_shape
.
type
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
i
)
{
return
(
i
>=
old_lens
.
size
()
and
i
<
0
);
}))
// TODO: When support for dynamic shapes is added to normalize_attributes,
// remove this restriction.
if
(
input_shape
.
dynamic
()
and
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
not
input_shape
.
dyn_dims
()[
axis
].
is_fixed
();
}))
{
MIGRAPHX_THROW
(
"SLICE:
input axis "
+
to_string_range
(
axes
)
+
" out of range
"
);
MIGRAPHX_THROW
(
"SLICE:
slicing is not allowed on non-fixed dynamic input axis
"
);
}
if
(
starts
.
size
()
!=
axes
.
size
()
or
axes
.
size
()
!=
ends
.
size
())
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
new_opts
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_strides
;
if
(
input_shape
.
dynamic
())
{
old_lens
=
input_shape
.
max_lens
();
new_mins
=
input_shape
.
min_lens
();
new_opts
=
input_shape
.
opt_lens
();
}
else
{
MIGRAPHX_THROW
(
"SLICE: inconsistent sizes"
);
old_lens
=
input_shape
.
lens
();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides
=
input_shape
.
strides
();
}
std
::
vector
<
std
::
size_t
>
new_lens
=
old_lens
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
auto
axis
=
axes
[
i
];
new_lens
[
axis
]
=
fix_index
(
old_lens
,
axis
,
ends
[
i
])
-
fix_index
(
old_lens
,
axis
,
starts
[
i
]);
auto
axis
=
axes
[
i
];
size_t
sliced_length
=
ends
[
i
]
-
starts
[
i
];
// A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens
[
axis
]
=
std
::
min
(
new_lens
[
axis
],
sliced_length
);
if
(
input_shape
.
dynamic
())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
// if the slice size is smaller than maxes but larger than mins
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
auto
sliced_opt_length
=
ends
[
i
]
-
starts
[
i
];
if
(
new_opts
[
axis
]
!=
0
)
new_opts
[
axis
]
=
sliced_opt_length
;
if
(
new_opts
[
axis
]
<
new_mins
[
axis
]
or
new_opts
[
axis
]
>
new_lens
[
axis
])
new_opts
[
axis
]
=
0
;
}
}
if
(
input_shape
.
dynamic
())
{
return
shape
{
t
,
new_mins
,
new_lens
,
new_opts
};
}
else
{
return
shape
{
t
,
new_lens
,
old_strides
};
}
return
shape
{
t
,
new_lens
,
old_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
output_shape
.
type_size
();
return
{
std
::
move
(
output_shape
),
[
=
]
{
return
input
.
data
()
+
offset
;
}};
auto
input
=
args
[
0
];
auto
offset
=
compute_offset
(
input
.
get_shape
())
*
dyn_out
.
computed_shape
.
type_size
();
return
{
dyn_out
.
computed_shape
,
[
=
]
{
return
input
.
data
()
+
offset
;
}};
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/softmax.hpp
View file @
9bff4331
...
...
@@ -53,15 +53,15 @@ struct softmax
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
if
(
inputs
.
at
(
0
).
packed
())
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
()
or
s0
.
packed
())
{
return
inputs
.
at
(
0
)
;
return
s0
;
}
else
{
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
inputs
.
at
(
0
).
type
(),
lens
};
return
{
s0
.
type
(),
s0
.
lens
()};
}
}
...
...
src/include/migraphx/op/squeeze.hpp
View file @
9bff4331
...
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -54,52 +55,85 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
if
(
input_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
vector
<
std
::
size_t
>
new_strides
;
if
(
axes
.
empty
())
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
input_shape
.
dyn_dims
()[
axis
]
!=
1
;
}))
{
MIGRAPHX_THROW
(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"
);
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
if
(
axes
.
empty
())
{
std
::
copy_if
(
input_shape
.
dyn_dims
().
cbegin
(),
input_shape
.
dyn_dims
().
cend
(),
std
::
back_inserter
(
dyn_dims
),
[
&
](
auto
dd
)
{
return
dd
!=
1
;
});
}
else
{
if
(
old_lens
[
i
]
!=
1
)
for
(
auto
i
:
range
(
input_shape
.
ndim
())
)
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
()[
i
]);
}
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
}
else
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
MIGRAPHX_THROW
(
"SQUEEZE: static axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
vector
<
std
::
size_t
>
new_strides
;
if
(
axes
.
empty
())
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
if
(
old_lens
[
i
]
!=
1
)
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
,
new_strides
};
else
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
,
new_strides
};
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
9bff4331
...
...
@@ -29,11 +29,20 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct
unsqueeze
{
std
::
vector
<
int64_t
>
axes
;
...
...
@@ -56,63 +65,89 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
if
(
input_shape
.
dynamic
())
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
return
shape
{
type
,
old_lens
};
else
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
if
(
not
steps
.
empty
())
{
MIGRAPHX_THROW
(
"UNSQUEEZE_dyn: nonempty steps attribute"
);
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
auto
new_ndim
=
input_shape
.
ndim
()
+
axes
.
size
();
std
::
size_t
k
=
0
;
for
(
auto
i
:
range
(
new_ndim
))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
dyn_dims
.
push_back
({
1
,
1
,
0
});
}
else
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
().
at
(
k
++
));
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
}
else
{
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
return
shape
{
type
,
old_lens
};
else
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
{
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
{
std
::
int64_t
step
=
1
;
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
std
::
int64_t
step
=
1
;
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
else
{
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
else
{
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
new_lens
[
i
]
=
old_lens
[
p
];
new_strides
[
i
]
=
old_strides
[
p
++
];
}
}
else
{
new_lens
[
i
]
=
old_lens
[
p
];
new_strides
[
i
]
=
old_strides
[
p
++
];
}
return
shape
{
type
,
new_lens
,
new_strides
};
}
return
shape
{
type
,
new_lens
,
new_strides
};
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
...
...
src/include/migraphx/op/where.hpp
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -42,9 +42,17 @@ struct where
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_dims
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
3
).
same_dims
();
auto
s1
=
inputs
.
at
(
1
);
auto
s2
=
inputs
.
at
(
2
);
if
(
s1
.
dynamic
()
or
s2
.
dynamic
())
{
if
(
s1
==
s2
)
return
s1
;
MIGRAPHX_THROW
(
"WHERE: dynamic input shapes must be the same"
);
}
// Compare two static shapes, returning a standard shape
if
(
s1
==
s2
and
s1
.
packed
())
{
return
s1
;
...
...
@@ -63,12 +71,12 @@ struct where
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
1
],
args
[
2
])([
&
](
auto
output
,
const
auto
x
,
const
auto
y
)
{
args
[
0
].
visit
([
&
](
const
auto
condition
)
{
par_for
(
out
put_shape
.
elements
(),
par_for
(
dyn_out
.
com
put
ed
_shape
.
elements
(),
[
&
](
auto
i
)
{
output
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
});
});
});
...
...
src/include/migraphx/operation.hpp
View file @
9bff4331
...
...
@@ -140,6 +140,8 @@ template <class T>
auto
compute_shape_op
(
rank
<
2
>
,
const
T
&
x
,
const
std
::
vector
<
shape
>&
inputs
)
->
decltype
(
x
.
normalize_compute_shape
(
inputs
))
{
if
(
inputs
.
empty
())
MIGRAPHX_THROW
(
"At least one input is required for "
+
x
.
name
());
dependent_type
<
operation
,
T
>
y
=
x
;
normalize_attributes
(
y
,
inputs
[
0
].
max_lens
());
return
any_cast
<
T
>
(
y
).
normalize_compute_shape
(
inputs
);
...
...
src/include/migraphx/
pass_config
.hpp
→
src/include/migraphx/
optimize_module
.hpp
View file @
9bff4331
...
...
@@ -21,18 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#ifndef MIGRAPHX_GUARD_PASS_CONFIG_HPP
#define MIGRAPHX_GUARD_PASS_CONFIG_HPP
#include <migraphx/env.hpp>
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MEMORY_COLORING
)
struct
module_pass_manager
;
/**
* Runs several passes in a loop
*/
struct
optimize_module
{
std
::
string
name
()
const
{
return
"optimize_module"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_PASS_CONFIG_HPP
#endif
src/include/migraphx/program.hpp
View file @
9bff4331
...
...
@@ -115,6 +115,7 @@ struct program
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
...
...
src/include/migraphx/register_op.hpp
View file @
9bff4331
...
...
@@ -33,15 +33,36 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// unregister all ops for specified target, useful when unloading dynamically plugged-in target lib
void
unregister_op
(
const
std
::
string
&
op_name
);
namespace
detail
{
struct
op_handler
{
operation
op
;
std
::
string
name
;
op_handler
(
const
operation
&
op_r
)
:
op
(
op_r
),
name
(
op
.
name
()){};
~
op_handler
()
{
unregister_op
(
name
);
}
};
}
// namespace detail
void
register_op_init
();
void
register_op
(
const
operation
&
op
);
operation
load_op
(
const
std
::
string
&
name
);
bool
has_op
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_operators
();
template
<
class
T
>
void
register_op
()
{
register_op
(
T
{});
register_op_init
();
// instantiate static op_map;
static
auto
op_h
=
detail
::
op_handler
(
T
{});
register_op
(
op_h
.
op
);
}
struct
register_op_action
...
...
src/include/migraphx/register_target.hpp
View file @
9bff4331
...
...
@@ -33,14 +33,28 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
register_target_init
();
void
register_target
(
const
target
&
t
);
void
unregister_target
(
const
std
::
string
&
name
);
target
make_target
(
const
std
::
string
&
name
);
std
::
vector
<
std
::
string
>
get_targets
();
namespace
detail
{
struct
target_handler
{
target
t
;
std
::
string
target_name
;
target_handler
(
const
target
&
t_r
)
:
t
(
t_r
),
target_name
(
t
.
name
())
{}
~
target_handler
()
{
unregister_target
(
target_name
);
}
};
}
// namespace detail
template
<
class
T
>
void
register_target
()
{
register_target
(
T
{});
register_target_init
();
static
auto
t_h
=
detail
::
target_handler
(
T
{});
register_target
(
t_h
.
t
);
}
struct
register_target_action
...
...
src/include/migraphx/replace_allocate.hpp
View file @
9bff4331
...
...
@@ -32,6 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
/**
* Replace `allocate` instructions with target allocations or output parameters.
*/
struct
replace_allocate
{
allocation_model
model
;
...
...
src/include/migraphx/serialize.hpp
View file @
9bff4331
...
...
@@ -93,7 +93,11 @@ auto to_value_impl(rank<4>, const optional<T>& x)
{
value
result
{};
if
(
x
.
has_value
())
<<<<<<<
HEAD
to_value
(
*
x
);
=======
return
to_value
(
*
x
);
>>>>>>>
develop
return
result
;
}
...
...
@@ -207,6 +211,7 @@ void from_value_impl(rank<5>, const value& v, T& x)
template
<
class
T
>
void
from_value_impl
(
rank
<
6
>
,
const
value
&
v
,
optional
<
T
>&
x
)
<<<<<<<
HEAD
{
if
(
not
v
.
is_null
())
x
=
from_value
<
T
>
(
v
);
...
...
@@ -214,26 +219,45 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{})
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
=======
>>>>>>>
develop
{
x
=
v
.
to
<
T
>
();
if
(
not
v
.
is_null
())
x
=
from_value
<
T
>
(
v
);
}
<<<<<<<
HEAD
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_enum
<
T
>{})
>
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
T
&
x
)
=======
template
<
class
T
,
MIGRAPHX_REQUIRES
(
std
::
is_arithmetic
<
T
>{}
or
std
::
is_enum
<
T
>
{})
>
void
from_value_impl
(
rank
<
7
>
,
const
value
&
v
,
T
&
x
)
>>>>>>>
develop
{
x
=
v
.
to
<
T
>
();
}
<<<<<<<
HEAD
inline
void
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
=======
inline
void
from_value_impl
(
rank
<
8
>
,
const
value
&
v
,
std
::
string
&
x
)
{
x
=
v
.
to
<
std
::
string
>
();
}
template
<
class
T
>
auto
from_value_impl
(
rank
<
9
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
from_value
(
v
),
void
())
>>>>>>>
develop
{
x
.
from_value
(
v
);
}
template
<
class
T
>
<<<<<<<
HEAD
auto
from_value_impl
(
rank
<
11
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
=======
auto
from_value_impl
(
rank
<
10
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
migraphx_from_value
(
v
,
x
),
void
())
>>>>>>>
develop
{
migraphx_from_value
(
v
,
x
);
}
...
...
@@ -249,7 +273,11 @@ value to_value(const T& x)
template
<
class
T
>
void
from_value
(
const
value
&
v
,
T
&
x
)
{
<<<<<<<
HEAD
detail
::
from_value_impl
(
rank
<
11
>
{},
v
,
x
);
=======
detail
::
from_value_impl
(
rank
<
10
>
{},
v
,
x
);
>>>>>>>
develop
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/shape.hpp
View file @
9bff4331
...
...
@@ -101,6 +101,19 @@ struct shape
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
// compare to fixed std::size_t dimension
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
// add and subtract fixed std::size_t dimension
dynamic_dimension
&
operator
+=
(
const
std
::
size_t
&
x
);
dynamic_dimension
&
operator
-=
(
const
std
::
size_t
&
x
);
friend
dynamic_dimension
operator
+
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
dynamic_dimension
operator
-
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
};
static
const
std
::
vector
<
type_t
>&
types
();
...
...
@@ -230,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool
dynamic
()
const
;
/// Return true if this shape or any of the sub_shapes are dynamic
bool
any_of_dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
with_lens
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
)
const
;
...
...
src/instruction.cpp
View file @
9bff4331
...
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
}
bool
instruction
::
is_undefined
()
const
{
if
(
op
.
name
()
==
"undefined"
)
{
return
true
;
}
else
if
(
this
->
inputs
().
empty
())
{
return
false
;
}
else
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
is_undefined
();
});
}
}
bool
instruction
::
can_eval
()
const
{
if
(
op
.
name
()
==
"@literal"
)
...
...
src/memory_coloring.cpp
0 → 100644
View file @
9bff4331
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DEBUG_MEMORY_COLORING
);
using
instruction_set
=
std
::
unordered_set
<
instruction_ref
>
;
using
instruction_set_map
=
std
::
unordered_map
<
instruction_ref
,
instruction_set
>
;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template
<
class
F
>
void
liveness
(
const
module
&
m
,
F
f
)
{
auto
implicit_deps
=
m
.
calc_implicit_deps
();
instruction_set
live_set
;
auto
rp
=
reverse
(
m
);
for
(
auto
rins
:
iterator_for
(
rp
))
// NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto
ins
=
std
::
prev
(
rins
.
base
());
// Add live variables
auto
add_live_variables
=
[
&
](
const
auto
&
inputs
)
{
for
(
auto
input
:
inputs
)
{
auto
i
=
instruction
::
get_output_alias
(
input
);
// Skip if variable comes from parent
if
(
not
m
.
has_instruction
(
i
))
continue
;
live_set
.
insert
(
i
);
}
};
add_live_variables
(
ins
->
inputs
());
add_live_variables
(
implicit_deps
[
ins
]);
// Remove last usage
auto
it
=
live_set
.
find
(
ins
);
if
(
it
!=
live_set
.
end
())
{
live_set
.
erase
(
it
);
f
(
ins
,
live_set
);
}
}
}
// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instruction that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map
build_conflict_table
(
const
module
&
m
,
std
::
string
allocation_op
)
{
instruction_set_map
conflict_table
;
liveness
(
m
,
[
&
](
auto
ins
,
auto
live_set
)
{
// Skip variables that aren't allocations
if
(
ins
->
name
()
!=
allocation_op
)
return
;
// Skip zero allocations
if
(
ins
->
get_shape
().
bytes
()
==
0
)
return
;
conflict_table
[
ins
];
for
(
auto
i
:
live_set
)
{
if
(
i
==
ins
)
continue
;
// Skip variables that aren't allocations
if
(
i
->
name
()
!=
allocation_op
)
continue
;
// Skip zero allocations
if
(
i
->
get_shape
().
bytes
()
==
0
)
continue
;
conflict_table
[
i
].
insert
(
ins
);
conflict_table
[
ins
].
insert
(
i
);
}
});
assert
(
std
::
all_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[](
auto
&&
pp
)
{
return
pp
.
second
.
count
(
pp
.
first
)
==
0
;
}));
return
conflict_table
;
}
// Check if intervals overlap
bool
is_overlap
(
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
x
,
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
y
)
{
return
std
::
max
(
x
.
first
,
y
.
first
)
<
std
::
min
(
x
.
second
,
y
.
second
);
}
struct
allocation_segment
{
using
segment
=
std
::
pair
<
std
::
size_t
,
std
::
size_t
>
;
std
::
unordered_map
<
instruction_ref
,
segment
>
ins2segment
;
const
segment
*
add_segment
(
instruction_ref
ins
,
segment
s
)
{
return
&
(
ins2segment
[
ins
]
=
s
);
}
const
segment
*
get_segment
(
instruction_ref
ins
)
const
{
auto
it
=
ins2segment
.
find
(
ins
);
if
(
it
==
ins2segment
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
// Remove segment for an instruction
void
remove
(
instruction_ref
ins
)
{
auto
it
=
ins2segment
.
find
(
ins
);
if
(
it
!=
ins2segment
.
end
())
{
ins2segment
.
erase
(
it
);
}
}
std
::
size_t
max
()
{
std
::
size_t
n
=
0
;
for
(
auto
&&
pp
:
ins2segment
)
{
auto
seg
=
pp
.
second
;
n
=
std
::
max
(
n
,
seg
.
second
);
}
return
n
;
}
template
<
class
Iterator
>
static
bool
overlaps
(
Iterator
first
,
Iterator
last
,
const
segment
&
s
)
{
return
std
::
any_of
(
first
,
last
,
[
&
](
auto
&&
t
)
{
return
is_overlap
(
s
,
t
);
});
}
static
bool
overlaps
(
const
std
::
set
<
segment
>&
segments
,
const
segment
&
s
)
{
return
overlaps
(
segments
.
begin
(),
segments
.
end
(),
s
);
}
static
auto
find_gap
(
const
std
::
set
<
segment
>&
segments
,
std
::
size_t
n
)
{
std
::
size_t
max_end
=
0
;
return
std
::
adjacent_find
(
segments
.
begin
(),
segments
.
end
(),
[
&
](
segment
x
,
segment
y
)
{
if
(
x
.
second
<
max_end
)
return
false
;
max_end
=
x
.
second
;
if
(
is_overlap
(
x
,
y
))
return
false
;
assert
(
y
.
first
>=
x
.
second
);
auto
k
=
y
.
first
-
x
.
second
;
return
(
k
>=
n
);
});
}
static
std
::
size_t
max_type_size
(
const
shape
&
s
)
{
return
std
::
accumulate
(
s
.
sub_shapes
().
begin
(),
s
.
sub_shapes
().
end
(),
s
.
type_size
(),
[](
auto
size
,
const
auto
&
sub
)
{
return
std
::
max
(
size
,
max_type_size
(
sub
));
});
}
static
std
::
size_t
compute_alignment
(
instruction_ref
ins
)
{
auto
alignment
=
max_type_size
(
ins
->
get_shape
());
// A rough estimate for the total number of elements
auto
n
=
ins
->
get_shape
().
bytes
()
/
alignment
;
// Check for vectorized alignment
if
(
n
>
4
)
{
auto
d
=
n
%
4
;
if
(
d
==
0
)
alignment
*=
4
;
if
(
d
==
2
)
alignment
*=
2
;
}
return
alignment
;
}
static
segment
next_segment
(
std
::
set
<
segment
>&
segments
,
instruction_ref
ins
,
std
::
size_t
alignment
)
{
assert
(
ins
->
get_shape
().
bytes
()
>
0
);
// Compute alignment
auto
n
=
1
+
(
ins
->
get_shape
().
bytes
()
-
1
)
/
alignment
;
assert
(
n
>
0
);
auto
start
=
0
;
// Insert at end if it cant fit at the begining
if
(
segments
.
empty
()
or
segments
.
begin
()
->
first
<=
n
)
{
auto
it
=
find_gap
(
segments
,
n
);
if
(
it
==
segments
.
end
())
it
=
std
::
max_element
(
segments
.
begin
(),
segments
.
end
(),
[
&
](
segment
x
,
segment
y
)
{
return
x
.
second
<
y
.
second
;
});
if
(
it
!=
segments
.
end
())
start
=
it
->
second
;
}
auto
s
=
segment
{
start
,
start
+
n
};
assert
(
not
overlaps
(
segments
,
s
));
segments
.
insert
(
s
);
return
s
;
}
static
std
::
unordered_map
<
instruction_ref
,
int
>
create_allocation_index
(
const
module
&
m
,
const
instruction_set_map
&
conflict_table
)
{
std
::
unordered_map
<
instruction_ref
,
int
>
result
;
int
i
=
0
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
not
contains
(
conflict_table
,
ins
))
continue
;
result
[
ins
]
=
i
++
;
}
return
result
;
}
// Build the allocation_color class from the conflict_table
static
allocation_segment
build
(
const
module
&
m
,
const
instruction_set_map
&
conflict_table
,
std
::
size_t
alignment
)
{
allocation_segment
as
{};
std
::
vector
<
instruction_ref
>
conflict_queue
;
// Add all allocations to the conflict_queue
std
::
transform
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
std
::
back_inserter
(
conflict_queue
),
[](
auto
&&
pp
)
{
return
pp
.
first
;
});
auto
alloc_index
=
create_allocation_index
(
m
,
conflict_table
);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std
::
sort
(
conflict_queue
.
begin
(),
conflict_queue
.
end
(),
by
(
std
::
greater
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
conflict_table
.
at
(
x
).
size
(),
x
->
get_shape
().
bytes
(),
alloc_index
.
at
(
x
));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for
(
auto
parent
:
conflict_queue
)
{
// Sort children by size
std
::
vector
<
instruction_ref
>
children
(
conflict_table
.
at
(
parent
).
begin
(),
conflict_table
.
at
(
parent
).
end
());
std
::
sort
(
children
.
begin
(),
children
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
auto
x
)
{
return
std
::
make_tuple
(
x
->
get_shape
().
bytes
(),
alloc_index
.
at
(
x
));
}));
assert
(
not
contains
(
children
,
parent
));
// This set is to track the segments already processed
std
::
set
<
segment
>
segments
;
// Add all segments for the children to the segments already processed
transform_if
(
children
.
begin
(),
children
.
end
(),
std
::
inserter
(
segments
,
segments
.
begin
()),
[
&
](
auto
child
)
{
return
as
.
get_segment
(
child
);
},
[
&
](
auto
child
)
{
return
*
as
.
get_segment
(
child
);
});
assert
(
as
.
get_segment
(
parent
)
==
nullptr
);
as
.
add_segment
(
parent
,
next_segment
(
segments
,
parent
,
alignment
));
}
// Reduce the number of segments
for
(
std
::
size_t
n
=
0
;
n
<
3
;
n
++
)
{
for
(
auto
parent
:
conflict_queue
)
{
auto
children
=
conflict_table
.
at
(
parent
);
// This set is to track the segments already processed
std
::
set
<
segment
>
segments
;
// Add all segments for the children to the segments already processed
transform_if
(
children
.
begin
(),
children
.
end
(),
std
::
inserter
(
segments
,
segments
.
begin
()),
[
&
](
auto
child
)
{
return
as
.
get_segment
(
child
);
},
[
&
](
auto
child
)
{
return
*
as
.
get_segment
(
child
);
});
// Get the segment for the parent
const
auto
*
parent_segment
=
as
.
get_segment
(
parent
);
assert
(
parent_segment
!=
nullptr
);
auto
s
=
next_segment
(
segments
,
parent
,
alignment
);
if
(
s
!=
*
parent_segment
and
s
.
second
<=
as
.
max
())
{
as
.
add_segment
(
parent
,
s
);
}
}
}
return
as
;
}
};
static
std
::
size_t
find_max_alignment
(
const
module
&
m
,
const
std
::
string
&
allocation_op
)
{
std
::
size_t
alignment
=
1
;
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
alignment
=
std
::
max
(
allocation_segment
::
compute_alignment
(
ins
),
alignment
);
}
return
alignment
;
}
void
memory_coloring
::
apply
(
module
&
m
)
const
{
const
std
::
size_t
alignment
=
find_max_alignment
(
m
,
allocation_op
);
auto
conflict_table
=
build_conflict_table
(
m
,
allocation_op
);
auto
as
=
allocation_segment
::
build
(
m
,
conflict_table
,
alignment
);
// All allocations should have a segment
assert
(
std
::
all_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[
&
](
auto
&&
pp
)
{
return
as
.
get_segment
(
pp
.
first
);
}));
// Adjacent allocations should not have overlapping segments
assert
(
std
::
none_of
(
conflict_table
.
begin
(),
conflict_table
.
end
(),
[
&
](
auto
&&
pp
)
{
auto
*
x
=
as
.
get_segment
(
pp
.
first
);
return
std
::
any_of
(
pp
.
second
.
begin
(),
pp
.
second
.
end
(),
[
&
](
auto
ins
)
{
auto
*
y
=
as
.
get_segment
(
ins
);
assert
(
x
and
y
);
return
is_overlap
(
*
x
,
*
y
);
});
}));
// Print out segments
if
(
enabled
(
MIGRAPHX_DEBUG_MEMORY_COLORING
{}))
{
for
(
auto
&&
pp
:
conflict_table
)
{
std
::
cout
<<
"------- conflict -------"
<<
std
::
endl
;
auto
s1
=
as
.
ins2segment
.
at
(
pp
.
first
);
std
::
cout
<<
s1
.
first
<<
", "
<<
s1
.
second
<<
": "
;
m
.
debug_print
(
pp
.
first
);
for
(
auto
ins
:
pp
.
second
)
{
auto
s2
=
as
.
ins2segment
.
at
(
ins
);
std
::
cout
<<
s2
.
first
<<
", "
<<
s2
.
second
<<
": "
;
m
.
debug_print
(
ins
);
}
}
}
// Total memory
std
::
size_t
n
=
as
.
max
()
*
alignment
;
// Replace allocations
auto
mem
=
m
.
add_parameter
(
"scratch"
,
shape
{
shape
::
int8_type
,
{
n
}});
for
(
auto
&&
[
ins
,
seg
]
:
as
.
ins2segment
)
{
assert
(
ins
->
name
()
==
allocation_op
);
auto
s
=
ins
->
get_shape
();
std
::
size_t
offset
=
seg
.
first
*
alignment
;
assert
(
offset
<
n
);
m
.
replace_instruction
(
ins
,
op
::
load
{
s
,
offset
},
mem
);
}
// Replace zero allocation
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
allocation_op
)
continue
;
assert
(
ins
->
get_shape
().
bytes
()
==
0
);
m
.
replace_instruction
(
ins
,
op
::
load
{
ins
->
get_shape
(),
0
},
mem
);
}
// Remove scratch parameter if its not used
if
(
mem
->
outputs
().
empty
())
{
m
.
remove_instruction
(
mem
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/module.cpp
View file @
9bff4331
...
...
@@ -166,6 +166,7 @@ void module::assign(const module& m)
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
}
else
if
(
ins
->
name
()
==
"@outline"
)
{
...
...
@@ -789,6 +790,22 @@ static std::string cpp_var_name(const std::string& name)
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
os
<<
"migraphx.op("
<<
enclose_name
(
op
.
name
());
auto
default_values
=
make_op
(
op
.
name
()).
to_value
();
for
(
auto
&&
x
:
v
)
{
auto
name
=
x
.
get_key
();
if
(
default_values
[
name
]
==
x
)
continue
;
os
<<
", "
<<
name
<<
"="
<<
to_json_string
(
x
.
without_key
());
}
os
<<
")"
;
}
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
...
...
@@ -804,6 +821,15 @@ static void print_make_op(std::ostream& os, const operation& op)
os
<<
")"
;
}
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape(type="
<<
to_json_string
(
s
.
type_string
())
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
...
...
@@ -813,6 +839,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os
<<
"}"
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
// cppcheck-suppress variableScope
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
std
::
vector
<
std
::
string
>
input_vars
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
input_vars
),
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
});
if
(
ins
!=
last
)
os
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
use_abs
=
false
;
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_literal("
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
os
<<
")"
;
os
<<
")"
<<
std
::
endl
;
seed
++
;
}
else
if
(
ins
->
name
()
==
"@param"
)
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
mname
<<
".add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
")"
<<
std
::
endl
;
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
".add_return(["
<<
join_strings
(
input_vars
,
", "
)
<<
"])"
<<
std
::
endl
;
}
else
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
os
<<
mname
<<
".add_instruction("
;
print_py_op
(
os
,
ins
->
get_operator
());
os
<<
", ["
<<
join_strings
(
input_vars
,
", "
)
<<
"]"
;
os
<<
")"
<<
std
::
endl
;
}
},
names
);
return
names
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
...
...
@@ -874,6 +962,8 @@ module::print_cpp(std::ostream& os,
return
names
;
}
void
module
::
print_py
(
std
::
ostream
&
os
)
const
{
this
->
print_py
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
...
...
src/normalize_attributes.cpp
View file @
9bff4331
...
...
@@ -30,13 +30,16 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
/**
* Parameters:
* vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
* normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
const
std
::
vector
<
int64_t
>&
axes
,
const
value
&
val
,
...
...
@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return
result
;
}
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool
normalize_attributes
(
operation
&
op
,
const
std
::
vector
<
std
::
size_t
>&
lens
)
{
bool
tuned
=
false
;
...
...
@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto
val
=
op
.
to_value
();
if
(
attrs
.
contains
(
"normalize_padding"
))
{
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
// for now, assume the dimensions to pad start at dim 2
auto
padding
=
val
.
at
(
attrs
.
at
(
"normalize_padding"
).
to
<
std
::
string
>
());
auto
padding_size
=
padding
.
size
();
auto
padding_start
=
2
;
if
(
padding_size
==
2
*
(
lens
.
size
()
-
padding_start
))
...
...
Prev
1
2
3
4
5
6
7
8
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment