Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cd4ab535
Commit
cd4ab535
authored
Jun 20, 2023
by
Khalique Ahmed
Browse files
manual merge
parents
3891ee58
a0fa3742
Changes
279
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
454 additions
and
69 deletions
+454
-69
src/include/migraphx/op/quantizelinear.hpp
src/include/migraphx/op/quantizelinear.hpp
+23
-12
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+2
-4
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+117
-5
src/include/migraphx/op/run_on_target.hpp
src/include/migraphx/op/run_on_target.hpp
+98
-0
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+6
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+3
-10
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+8
-9
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+6
-0
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+2
-2
src/include/migraphx/process.hpp
src/include/migraphx/process.hpp
+3
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+3
-0
src/include/migraphx/promote_literals.hpp
src/include/migraphx/promote_literals.hpp
+47
-0
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+1
-1
src/include/migraphx/serialize.hpp
src/include/migraphx/serialize.hpp
+2
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+46
-14
src/include/migraphx/split_single_dyn_dim.hpp
src/include/migraphx/split_single_dyn_dim.hpp
+48
-0
src/include/migraphx/tf.hpp
src/include/migraphx/tf.hpp
+2
-0
src/include/migraphx/value.hpp
src/include/migraphx/value.hpp
+15
-2
src/instruction.cpp
src/instruction.cpp
+5
-1
src/module.cpp
src/module.cpp
+17
-6
No files found.
src/include/migraphx/op/quantizelinear.hpp
View file @
cd4ab535
...
@@ -38,9 +38,22 @@ namespace op {
...
@@ -38,9 +38,22 @@ namespace op {
struct
quantizelinear
struct
quantizelinear
{
{
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
std
::
string
name
()
const
{
return
"quantizelinear"
;
}
value
attributes
()
const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return
{{
"pointwise"
,
true
}};
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
same_dims
();
check_shapes
{
inputs
,
*
this
}.
same_dims
().
has
(
2
,
3
);
if
(
inputs
[
0
].
type
()
!=
inputs
[
1
].
type
())
{
MIGRAPHX_THROW
(
"QUANTIZELINEAR: Scales and input must be the same type"
);
}
if
(
inputs
.
size
()
==
3
)
if
(
inputs
.
size
()
==
3
)
{
{
return
{
inputs
[
2
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
return
{
inputs
[
2
].
type
(),
inputs
[
0
].
lens
(),
inputs
[
0
].
strides
()};
...
@@ -61,17 +74,15 @@ struct quantizelinear
...
@@ -61,17 +74,15 @@ struct quantizelinear
argument
result
{
output_shape
};
argument
result
{
output_shape
};
visit_all
(
result
,
y_zero_point
)([
&
](
auto
output
,
auto
zero_pts
)
{
visit_all
(
result
,
y_zero_point
)([
&
](
auto
output
,
auto
zero_pts
)
{
x
.
visit
([
&
](
auto
input
)
{
visit_all
(
x
,
y_scale
)([
&
](
auto
input
,
auto
scales
)
{
y_scale
.
visit
([
&
](
auto
scales
)
{
using
quant_type
=
typename
decltype
(
output
)
::
value_type
;
using
quant_type
=
typename
decltype
(
output
)
::
value_type
;
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
min_value
=
std
::
numeric_limits
<
quant_type
>::
min
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
auto
max_value
=
std
::
numeric_limits
<
quant_type
>::
max
();
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
round
(
input
[
i
]
/
scales
[
i
]))
+
int64_t
quantized
=
static_cast
<
int64_t
>
(
std
::
round
(
input
[
i
]
/
scales
[
i
]))
+
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
static_cast
<
int64_t
>
(
zero_pts
[
i
]);
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
output
[
i
]
=
std
::
max
(
static_cast
<
int64_t
>
(
min_value
),
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
std
::
min
(
static_cast
<
int64_t
>
(
max_value
),
quantized
));
});
});
});
});
});
});
});
...
...
src/include/migraphx/op/reduce_op.hpp
View file @
cd4ab535
...
@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
...
@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
{
{
value
normalize
;
value
normalize
;
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
return
{{
"normalize_axes"
,
normalize
}
,
{
"reduce"
,
true
}
};
}
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
...
@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived>
...
@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived>
auto
tuned_axes
=
tune_axes
(
output_dyn_dims
.
size
());
auto
tuned_axes
=
tune_axes
(
output_dyn_dims
.
size
());
for
(
const
auto
&
axis
:
tuned_axes
)
for
(
const
auto
&
axis
:
tuned_axes
)
{
{
// At the time of writing, there's no functional difference between
output_dyn_dims
[
axis
]
=
{
1
,
1
};
// optimum of 0 (no opt) or 1.
output_dyn_dims
[
axis
]
=
{
1
,
1
,
0
};
}
}
return
shape
{
s
.
type
(),
output_dyn_dims
};
return
shape
{
s
.
type
(),
output_dyn_dims
};
...
...
src/include/migraphx/op/reshape.hpp
View file @
cd4ab535
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -96,9 +97,115 @@ struct reshape
...
@@ -96,9 +97,115 @@ struct reshape
return
{
s0
.
type
(),
output_dyn_dims
};
return
{
s0
.
type
(),
output_dyn_dims
};
}
}
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>=
dim
;
});
if
(
x
!=
dim
)
return
start
;
return
it
;
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
auto
can_strides_merge
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
assert
(
std
::
distance
(
dim_start
,
dim_last
)
==
std
::
distance
(
stride_start
,
stride_last
));
auto
cstride
=
*
std
::
prev
(
stride_last
);
return
std
::
equal
(
std
::
make_reverse_iterator
(
dim_last
),
std
::
make_reverse_iterator
(
dim_start
+
1
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
std
::
make_reverse_iterator
(
stride_start
),
[
&
](
auto
dim
,
auto
stride
)
{
cstride
*=
dim
;
return
stride
==
cstride
;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static
optional
<
shape
>
reshape_dims
(
const
shape
&
input
,
const
std
::
vector
<
std
::
size_t
>&
rdims
)
{
if
(
input
.
standard
())
return
shape
{
input
.
type
(),
rdims
};
const
auto
&
idims
=
input
.
lens
();
const
auto
&
istrides
=
input
.
strides
();
std
::
vector
<
std
::
size_t
>
rstrides
;
std
::
size_t
i
=
0
;
std
::
size_t
r
=
0
;
while
(
i
<
idims
.
size
()
and
r
<
rdims
.
size
())
{
auto
idim
=
idims
[
i
];
auto
rdim
=
rdims
[
r
];
if
(
rdim
==
idim
)
{
rstrides
.
push_back
(
istrides
[
i
]);
}
// squeeze
else
if
(
rdim
>
idim
)
{
auto
start
=
idims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
idims
.
end
(),
rdim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
i
+
n
)
<=
istrides
.
size
());
if
(
not
can_strides_merge
(
start
,
it
+
1
,
istrides
.
begin
()
+
i
,
istrides
.
begin
()
+
i
+
n
+
1
))
return
nullopt
;
i
+=
n
;
rstrides
.
push_back
(
istrides
[
i
]);
}
// unsqueeze
else
// if(rdim < idim)
{
auto
start
=
rdims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
rdims
.
end
(),
idim
);
if
(
it
==
start
)
return
nullopt
;
auto
n
=
it
-
start
;
assert
((
r
+
n
)
<=
rdims
.
size
());
auto
stride
=
istrides
[
i
]
*
idim
;
std
::
for_each
(
start
,
it
+
1
,
[
&
](
auto
dim
)
{
stride
/=
dim
;
rstrides
.
push_back
(
stride
);
});
r
+=
n
;
}
i
++
;
r
++
;
}
// Handle trailing 1s
if
(
rstrides
.
size
()
<
rdims
.
size
()
and
not
rstrides
.
empty
())
{
auto
stride
=
rstrides
.
back
();
for
(
auto
d
:
range
(
rdims
.
begin
()
+
rstrides
.
size
(),
rdims
.
end
()))
{
if
(
d
!=
1
)
return
nullopt
;
rstrides
.
push_back
(
stride
);
}
}
if
(
rdims
.
size
()
!=
rstrides
.
size
())
return
nullopt
;
return
shape
{
input
.
type
(),
rdims
,
rstrides
};
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
standard
(
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
...
@@ -125,12 +232,17 @@ struct reshape
...
@@ -125,12 +232,17 @@ struct reshape
}
}
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
auto
s
=
reshape_dims
(
inputs
.
front
(),
rdims
);
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
if
(
not
s
.
has_value
())
MIGRAPHX_THROW
(
"Reshape on axis that is not packed."
);
if
(
s
->
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
std
::
to_string
(
s
.
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
s
->
elements
())
+
" elements whereas the input has "
+
std
::
to_string
(
inputs
.
front
().
elements
()));
std
::
to_string
(
inputs
.
front
().
elements
()));
return
s
;
assert
(
s
->
bytes
()
==
inputs
.
front
().
bytes
());
return
*
s
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
src/include/migraphx/op/run_on_target.hpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_ON_TARGET_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_ON_TARGET_HPP
#include <unordered_map>
#include <vector>
#include <set>
#include <algorithm>
#include <migraphx/config.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/module.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
run_on_target
{
std
::
size_t
target_id
=
0
;
std
::
string
name
()
const
{
return
"run_on_target"
;
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
target_id
,
"target_id"
));
}
migraphx
::
shape
compute_shape
(
const
std
::
vector
<
migraphx
::
shape
>&
inputs
,
std
::
vector
<
migraphx
::
module_ref
>
mods
)
const
{
if
(
mods
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"RUN_ON_TARGET: must have exactly 1 module argument"
);
}
auto
*
mod_input
=
mods
.
front
();
if
(
inputs
.
size
()
!=
mod_input
->
get_parameter_shapes
().
size
())
{
MIGRAPHX_THROW
(
"RUN_ON_TARGET: Mismatched number of input parameters"
);
}
auto
mod_out_shapes
=
mod_input
->
get_output_shapes
();
return
mod_out_shapes
;
}
migraphx
::
argument
compute
(
const
migraphx
::
shape
&
,
const
std
::
vector
<
migraphx
::
argument
>&
args
,
const
std
::
vector
<
migraphx
::
module_ref
>&
mods
,
const
std
::
function
<
std
::
vector
<
migraphx
::
argument
>
(
migraphx
::
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
migraphx
::
argument
>&
)
>&
run
)
const
{
std
::
unordered_map
<
std
::
string
,
migraphx
::
argument
>
params
;
std
::
set
<
std
::
string
>
pnames
;
const
auto
*
smod
=
mods
.
front
();
assert
(
mods
.
size
()
==
1
);
auto
names
=
smod
->
get_parameter_names
();
pnames
.
insert
(
names
.
begin
(),
names
.
end
());
assert
(
pnames
.
size
()
==
args
.
size
());
std
::
transform
(
pnames
.
begin
(),
pnames
.
end
(),
args
.
begin
(),
std
::
inserter
(
params
,
params
.
end
()),
[](
auto
&&
name
,
auto
&&
arg
)
{
return
std
::
make_pair
(
name
,
arg
);
});
auto
*
mod
=
mods
.
front
();
auto
results
=
run
(
mod
,
params
);
return
migraphx
::
argument
{
results
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/select_module.hpp
View file @
cd4ab535
...
@@ -57,6 +57,7 @@ struct select_module
...
@@ -57,6 +57,7 @@ struct select_module
param_names
.
cend
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
not
contains
(
pn
,
"#output_"
);
});
[](
auto
pn
)
{
return
not
contains
(
pn
,
"#output_"
);
});
std
::
sort
(
ret
.
begin
(),
ret
.
end
());
return
ret
;
return
ret
;
}
}
...
@@ -68,6 +69,8 @@ struct select_module
...
@@ -68,6 +69,8 @@ struct select_module
param_names
.
cend
(),
param_names
.
cend
(),
std
::
back_inserter
(
ret
),
std
::
back_inserter
(
ret
),
[](
auto
pn
)
{
return
contains
(
pn
,
"#output_"
);
});
[](
auto
pn
)
{
return
contains
(
pn
,
"#output_"
);
});
// needs to be sorted to ensure output parameter ordering
std
::
sort
(
ret
.
begin
(),
ret
.
end
());
return
ret
;
return
ret
;
}
}
...
@@ -111,6 +114,7 @@ struct select_module
...
@@ -111,6 +114,7 @@ struct select_module
// One tuple output parameter in main module to multiple output parameters in submodule
// One tuple output parameter in main module to multiple output parameters in submodule
auto
out_param_names
=
get_output_parameter_names
(
module_to_run
);
auto
out_param_names
=
get_output_parameter_names
(
module_to_run
);
auto
param_shapes
=
module_to_run
->
get_parameter_shapes
();
auto
output_sub_objects
=
args
.
back
().
get_sub_objects
();
auto
output_sub_objects
=
args
.
back
().
get_sub_objects
();
assert
(
out_param_names
.
size
()
==
output_sub_objects
.
size
());
assert
(
out_param_names
.
size
()
==
output_sub_objects
.
size
());
std
::
transform
(
out_param_names
.
begin
(),
std
::
transform
(
out_param_names
.
begin
(),
...
@@ -118,10 +122,10 @@ struct select_module
...
@@ -118,10 +122,10 @@ struct select_module
output_sub_objects
.
begin
(),
output_sub_objects
.
begin
(),
std
::
inserter
(
p_map
,
p_map
.
end
()),
std
::
inserter
(
p_map
,
p_map
.
end
()),
[
&
](
auto
&&
name
,
auto
&&
a
)
{
[
&
](
auto
&&
name
,
auto
&&
a
)
{
auto
ps
=
module_to_run
->
get_
param
eter
_shape
(
name
);
auto
ps
=
param_shape
s
.
at
(
name
);
if
(
a
.
get_shape
()
!=
ps
)
if
(
a
.
get_shape
()
!=
ps
)
{
{
assert
(
ps
.
bytes
()
=
=
a
.
get_shape
().
bytes
());
assert
(
ps
.
bytes
()
<
=
a
.
get_shape
().
bytes
());
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
return
std
::
make_pair
(
name
,
a
.
reshape
(
ps
));
}
}
else
else
...
...
src/include/migraphx/op/slice.hpp
View file @
cd4ab535
...
@@ -111,16 +111,15 @@ struct slice
...
@@ -111,16 +111,15 @@ struct slice
// For a static shape, old_lens will be adjusted to a new size
// For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
// while updating the old mins and opt
imal
s if possible.
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
new_mins
;
std
::
vector
<
std
::
size_t
>
new_opts
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_lens
;
std
::
vector
<
std
::
size_t
>
old_strides
;
std
::
vector
<
std
::
size_t
>
old_strides
;
// Doesn't handle optimals
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
old_lens
=
input_shape
.
max_lens
();
old_lens
=
input_shape
.
max_lens
();
new_mins
=
input_shape
.
min_lens
();
new_mins
=
input_shape
.
min_lens
();
new_opts
=
input_shape
.
opt_lens
();
}
}
else
else
{
{
...
@@ -146,17 +145,11 @@ struct slice
...
@@ -146,17 +145,11 @@ struct slice
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
std
::
size_t
sliced_min_length
=
ends
[
i
]
-
starts
[
i
];
// if the slice size is smaller than maxes but larger than mins
// if the slice size is smaller than maxes but larger than mins
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
new_mins
[
axis
]
=
std
::
min
(
sliced_min_length
,
new_mins
[
axis
]);
auto
sliced_opt_length
=
ends
[
i
]
-
starts
[
i
];
if
(
new_opts
[
axis
]
!=
0
)
new_opts
[
axis
]
=
sliced_opt_length
;
if
(
new_opts
[
axis
]
<
new_mins
[
axis
]
or
new_opts
[
axis
]
>
new_lens
[
axis
])
new_opts
[
axis
]
=
0
;
}
}
}
}
if
(
input_shape
.
dynamic
())
if
(
input_shape
.
dynamic
())
{
{
return
shape
{
t
,
new_mins
,
new_lens
,
new_opts
};
return
shape
{
t
,
new_mins
,
new_lens
,
{}
};
}
}
else
else
{
{
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
cd4ab535
...
@@ -81,7 +81,7 @@ struct unsqueeze
...
@@ -81,7 +81,7 @@ struct unsqueeze
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
{
dyn_dims
.
push_back
({
1
,
1
,
0
});
dyn_dims
.
push_back
({
1
,
1
});
}
}
else
else
{
{
...
@@ -95,13 +95,10 @@ struct unsqueeze
...
@@ -95,13 +95,10 @@ struct unsqueeze
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
auto
is_scalar
=
input_shape
.
scalar
();
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
if
(
is_scalar
and
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
return
shape
{
type
,
old_lens
};
return
shape
{
type
,
old_lens
};
else
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
if
(
steps
.
size
()
>
axes
.
size
())
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
...
@@ -121,13 +118,15 @@ struct unsqueeze
...
@@ -121,13 +118,15 @@ struct unsqueeze
step
=
steps
[
axis_idx
];
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
if
(
is_scalar
and
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be 1 when input is scalar"
);
new_lens
[
i
]
=
step
;
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
if
(
p
<
old_strides
.
size
())
{
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
new_strides
[
i
]
=
is_scalar
?
1
:
old_strides
[
p
]
*
old_lens
[
p
];
}
}
else
else
{
{
...
...
src/include/migraphx/pass_manager.hpp
View file @
cd4ab535
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
#include <vector>
...
@@ -39,12 +40,17 @@ struct module_pass_manager
...
@@ -39,12 +40,17 @@ struct module_pass_manager
virtual
module
&
get_module
()
=
0
;
virtual
module
&
get_module
()
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
get_common_parent
()
=
0
;
virtual
module
*
get_common_parent
()
=
0
;
virtual
module
*
get_root_module
()
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
protected:
protected:
virtual
~
module_pass_manager
()
{}
virtual
~
module_pass_manager
()
{}
};
};
void
run_passes
(
program
&
prog
,
module_ref
root_mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
module
&
mod
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
void
run_passes
(
program
&
prog
,
const
std
::
vector
<
pass
>&
passes
,
tracer
trace
=
tracer
{});
...
...
src/include/migraphx/permutation.hpp
View file @
cd4ab535
...
@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
...
@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
}
}
/*!
/*!
* Returns the permutation
needed to apply to the shape
to undo the
current
permutation
* Returns the
inverse
permutation
that could be applied
to undo the
inputted
permutation
*/
*/
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
);
/*!
/*!
* Finds the permutation
most likely from a transpose operator that has been applied to the shape.
* Finds the permutation
that would make the shape not transposed (refering to shape.transposed())
*/
*/
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/process.hpp
View file @
cd4ab535
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <functional>
#include <string>
#include <string>
#include <memory>
#include <memory>
...
@@ -36,6 +37,7 @@ struct process_impl;
...
@@ -36,6 +37,7 @@ struct process_impl;
struct
process
struct
process
{
{
using
writer
=
std
::
function
<
void
(
const
char
*
,
std
::
size_t
)
>
;
process
(
const
std
::
string
&
cmd
);
process
(
const
std
::
string
&
cmd
);
// move constructor
// move constructor
...
@@ -49,6 +51,7 @@ struct process
...
@@ -49,6 +51,7 @@ struct process
process
&
cwd
(
const
fs
::
path
&
p
);
process
&
cwd
(
const
fs
::
path
&
p
);
void
exec
();
void
exec
();
void
write
(
std
::
function
<
void
(
process
::
writer
)
>
pipe_in
);
private:
private:
std
::
unique_ptr
<
process_impl
>
impl
;
std
::
unique_ptr
<
process_impl
>
impl
;
...
...
src/include/migraphx/program.hpp
View file @
cd4ab535
...
@@ -92,6 +92,9 @@ struct program
...
@@ -92,6 +92,9 @@ struct program
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
target
&
t
,
compile_options
options
=
compile_options
{});
void
compile
(
const
std
::
vector
<
target
>&
targets
,
std
::
vector
<
compile_options
>
compile_opts
=
{});
bool
is_compiled
()
const
;
bool
is_compiled
()
const
;
void
finalize
();
void
finalize
();
...
...
src/include/migraphx/promote_literals.hpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PROMOTE_LITERALS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PROMOTE_LITERALS_HPP
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Replace literals in submodules with literals in the root module.
* Intended to allow for reuse of the literals between submodules.
*/
struct
promote_literals
{
std
::
string
name
()
const
{
return
"promote_literals"
;
}
void
apply
(
module_pass_manager
&
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/reflect.hpp
View file @
cd4ab535
...
@@ -78,7 +78,7 @@ template <class T>
...
@@ -78,7 +78,7 @@ template <class T>
struct
wrapper
struct
wrapper
{
{
using
type
=
typename
remove_rvalue_reference
<
T
>::
type
;
using
type
=
typename
remove_rvalue_reference
<
T
>::
type
;
type
data
;
type
data
;
// NOLINT
type
get
()
const
{
return
data
;
}
type
get
()
const
{
return
data
;
}
};
};
...
...
src/include/migraphx/serialize.hpp
View file @
cd4ab535
...
@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
...
@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
}
}
template
<
class
T
>
template
<
class
T
>
auto
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
void
())
auto
from_value_impl
(
rank
<
4
>
,
const
value
&
v
,
T
&
x
)
->
decltype
(
x
.
insert
(
*
x
.
begin
()),
std
::
declval
<
typename
T
::
mapped_type
>
(),
void
())
{
{
x
.
clear
();
x
.
clear
();
for
(
auto
&&
e
:
v
)
for
(
auto
&&
e
:
v
)
...
...
src/include/migraphx/shape.hpp
View file @
cd4ab535
...
@@ -29,10 +29,12 @@
...
@@ -29,10 +29,12 @@
#include <ostream>
#include <ostream>
#include <numeric>
#include <numeric>
#include <memory>
#include <memory>
#include <set>
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -87,12 +89,12 @@ struct shape
...
@@ -87,12 +89,12 @@ struct shape
{
{
std
::
size_t
min
=
0
;
std
::
size_t
min
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
max
=
0
;
std
::
size_t
opt
=
0
;
std
::
set
<
std
::
size_t
>
opt
imals
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
imals
,
"opt
imals
"
));
}
}
bool
is_fixed
()
const
;
bool
is_fixed
()
const
;
...
@@ -132,11 +134,12 @@ struct shape
...
@@ -132,11 +134,12 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
shape
(
type_t
t
,
std
::
vector
<
dynamic_dimension
>
dims
);
// Construct a dynamic shape from three sets of lengths (of the same rank)
// Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape
(
type_t
t
,
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
mins
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
maxes
,
std
::
vector
<
std
::
size_t
>
opt
s
);
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt
imals_list
);
template
<
class
Range
>
template
<
class
Range
>
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
shape
(
type_t
t
,
const
Range
&
l
)
:
shape
(
t
,
std
::
vector
<
std
::
size_t
>
(
l
.
begin
(),
l
.
end
()))
...
@@ -153,14 +156,34 @@ struct shape
...
@@ -153,14 +156,34 @@ struct shape
shape
(
const
std
::
vector
<
shape
>&
subs
);
shape
(
const
std
::
vector
<
shape
>&
subs
);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static
shape
static
shape
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
);
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
);
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
/*!
/*!
* The number of dimensions in the shape.
* The number of dimensions in the shape
, either static or dynamic
.
* Same as the number of indices required to get a data value.
* Same as the number of indices required to get a data value.
*/
*/
std
::
size_t
ndim
()
const
;
std
::
size_t
ndim
()
const
;
...
@@ -186,21 +209,21 @@ struct shape
...
@@ -186,21 +209,21 @@ struct shape
/*!
/*!
* Minimum lengths for dynamic shape.
* Minimum lengths for dynamic shape.
* lens() for
fixed
shape.
* lens() for
static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
std
::
vector
<
std
::
size_t
>
min_lens
()
const
;
/*!
/*!
* Maximum lengths for dynamic shape.
* Maximum lengths for dynamic shape.
* lens() for
fixed
shape.
* lens() for
static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
std
::
vector
<
std
::
size_t
>
max_lens
()
const
;
/*!
/*!
* Optimum lengths for dynamic shape.
* Optimum lengths for dynamic shape.
*
lens() for fixed
shape.
*
Empty for static
shape.
*/
*/
std
::
vector
<
std
::
size_t
>
opt_lens
()
const
;
std
::
vector
<
std
::
set
<
std
::
size_t
>
>
opt_lens
()
const
;
/// Map multiple indices to space index
/// Map multiple indices to space index
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
std
::
size_t
index
(
std
::
initializer_list
<
std
::
size_t
>
l
)
const
;
...
@@ -219,11 +242,15 @@ struct shape
...
@@ -219,11 +242,15 @@ struct shape
/// Map element index to space index
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
/// Map element index to multi-dimensional index
void
multi_copy
(
std
::
size_t
i
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
idx
)
const
;
/// Map element index to multi-dimensional index and put them them into location provided by
/// pointers
void
multi_copy
(
std
::
size_t
idx
,
std
::
size_t
*
start
,
const
std
::
size_t
*
end
)
const
;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
no
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// padding
///
no
padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
@@ -253,9 +280,12 @@ struct shape
...
@@ -253,9 +280,12 @@ struct shape
shape
with_type
(
type_t
t
)
const
;
shape
with_type
(
type_t
t
)
const
;
// convert the shape to an equivalent dynamic shape
// convert the shape to an equivalent dynamic shape
with empty optimals
shape
to_dynamic
()
const
;
shape
to_dynamic
()
const
;
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape
to_static
(
std
::
size_t
x
)
const
;
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
==
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
...
@@ -269,6 +299,8 @@ struct shape
...
@@ -269,6 +299,8 @@ struct shape
type
min
()
const
{
return
std
::
numeric_limits
<
type
>::
lowest
();
}
type
min
()
const
{
return
std
::
numeric_limits
<
type
>::
lowest
();
}
type
nan
()
const
{
return
std
::
numeric_limits
<
type
>::
quiet_NaN
();
}
template
<
class
U
>
template
<
class
U
>
type
operator
()(
U
u
)
const
type
operator
()(
U
u
)
const
{
{
...
...
src/include/migraphx/split_single_dyn_dim.hpp
0 → 100644
View file @
cd4ab535
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic.
*/
struct
split_single_dyn_dim
{
std
::
string
name
()
const
{
return
"split_single_dyn_dim"
;
}
void
apply
(
module_pass_manager
&
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tf.hpp
View file @
cd4ab535
...
@@ -43,6 +43,8 @@ struct tf_options
...
@@ -43,6 +43,8 @@ struct tf_options
/// Create a program from a tf pb file (default is nhwc format)
/// Create a program from a tf pb file (default is nhwc format)
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
program
parse_tf
(
const
std
::
string
&
name
,
const
tf_options
&
options
=
tf_options
{});
std
::
vector
<
std
::
string
>
get_tf_operators
();
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/value.hpp
View file @
cd4ab535
...
@@ -392,8 +392,8 @@ struct value
...
@@ -392,8 +392,8 @@ struct value
return; \
return; \
}
}
MIGRAPHX_VISIT_VALUE_TYPES
(
MIGRAPHX_VALUE_GENERATE_CASE_VALUE
)
MIGRAPHX_VISIT_VALUE_TYPES
(
MIGRAPHX_VALUE_GENERATE_CASE_VALUE
)
MIGRAPHX_VALUE_GENERATE_CASE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
array
,
)
MIGRAPHX_VALUE_GENERATE_CASE
(
object
,
)
MIGRAPHX_VALUE_GENERATE_CASE
_VALUE
(
object
,
)
}
}
MIGRAPHX_THROW
(
"Unknown type"
);
MIGRAPHX_THROW
(
"Unknown type"
);
}
}
...
@@ -461,6 +461,8 @@ struct value
...
@@ -461,6 +461,8 @@ struct value
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
value
&
d
);
std
::
size_t
hash
()
const
;
void
debug_print
(
bool
show_type
=
false
)
const
;
void
debug_print
(
bool
show_type
=
false
)
const
;
type_t
get_type
()
const
;
type_t
get_type
()
const
;
...
@@ -481,4 +483,15 @@ struct value
...
@@ -481,4 +483,15 @@ struct value
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
namespace
std
{
template
<
>
struct
hash
<
migraphx
::
value
>
{
using
argument_type
=
migraphx
::
value
;
using
result_type
=
std
::
size_t
;
result_type
operator
()(
const
migraphx
::
value
&
x
)
const
{
return
x
.
hash
();
}
};
}
// namespace std
#endif
#endif
src/instruction.cpp
100755 → 100644
View file @
cd4ab535
...
@@ -406,6 +406,9 @@ void instruction::print(std::ostream& os,
...
@@ -406,6 +406,9 @@ void instruction::print(std::ostream& os,
// skip return instruction shape
// skip return instruction shape
if
(
ins
->
name
()
!=
"@return"
)
if
(
ins
->
name
()
!=
"@return"
)
os
<<
" -> "
<<
ins
->
get_shape
();
os
<<
" -> "
<<
ins
->
get_shape
();
// print tid
os
<<
", target_id="
<<
ins
->
target_id
;
}
}
static
void
debug_name
(
std
::
ostream
&
os
,
const
instruction
&
ins
)
static
void
debug_name
(
std
::
ostream
&
os
,
const
instruction
&
ins
)
...
@@ -469,7 +472,8 @@ operation instruction::normalized_operator() const
...
@@ -469,7 +472,8 @@ operation instruction::normalized_operator() const
}
}
return
o
;
return
o
;
}
}
std
::
size_t
instruction
::
get_target_id
()
const
{
return
target_id
;
}
void
instruction
::
set_target_id
(
std
::
size_t
tid
)
{
this
->
target_id
=
tid
;
}
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
std
::
vector
<
shape
>
to_shapes
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
{
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
vector
<
shape
>
shapes
(
args
.
size
());
...
...
src/module.cpp
View file @
cd4ab535
...
@@ -595,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
...
@@ -595,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
}
}
}
}
std
::
vector
<
instruction_ref
>
module
::
get_returns
()
const
{
auto
last
=
std
::
prev
(
this
->
end
());
if
(
last
->
name
()
==
"@return"
)
return
last
->
inputs
();
return
{
last
};
}
instruction_ref
module
::
validate
()
const
instruction_ref
module
::
validate
()
const
{
{
return
std
::
find_if
(
return
std
::
find_if
(
...
@@ -715,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
...
@@ -715,15 +723,15 @@ std::unordered_map<instruction_ref, std::string> module::print(
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
std
::
string
var_name
;
std
::
string
var_name
;
if
(
not
this
->
name
().
empty
()
and
this
->
name
()
!=
"main"
)
var_name
=
this
->
name
()
+
":"
;
if
(
ins
->
name
()
==
"@param"
)
if
(
ins
->
name
()
==
"@param"
)
{
{
var_name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
var_name
.
append
(
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
)
;
}
}
else
else
{
{
var_name
=
this
->
name
();
var_name
.
append
(
"@"
+
std
::
to_string
(
count
));
var_name
.
append
((
this
->
name
().
empty
()
?
"@"
:
":@"
));
var_name
.
append
(
std
::
to_string
(
count
));
}
}
// count every instruction so index matches loc in the printout program
// count every instruction so index matches loc in the printout program
count
++
;
count
++
;
...
@@ -787,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
...
@@ -787,7 +795,10 @@ static std::string to_c_id(const std::string& name, char rep = '_')
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
{
{
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
std
::
string
prefix
=
"x_"
;
if
(
not
contains
(
name
,
"@"
))
prefix
=
"p_"
;
return
to_c_id
(
prefix
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -867,7 +878,7 @@ module::print_py(std::ostream& os,
...
@@ -867,7 +878,7 @@ module::print_py(std::ostream& os,
use_abs
=
false
;
use_abs
=
false
;
if
(
use_abs
)
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_
literal
("
;
os
<<
"migraphx.generate_
argument
("
;
print_py_shape
(
os
,
ins
->
get_shape
());
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
if
(
use_abs
)
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment