Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f8a75f8a
Commit
f8a75f8a
authored
Dec 07, 2023
by
Paul
Browse files
Merge
parents
74448ed6
d00fdf6e
Changes
242
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1037 additions
and
163 deletions
+1037
-163
src/include/migraphx/op/scatternd_max.hpp
src/include/migraphx/op/scatternd_max.hpp
+47
-0
src/include/migraphx/op/scatternd_min.hpp
src/include/migraphx/op/scatternd_min.hpp
+47
-0
src/include/migraphx/op/scatternd_op.hpp
src/include/migraphx/op/scatternd_op.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+1
-0
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+5
-4
src/include/migraphx/op/unique.hpp
src/include/migraphx/op/unique.hpp
+323
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-0
src/include/migraphx/par.hpp
src/include/migraphx/par.hpp
+133
-0
src/include/migraphx/par_for.hpp
src/include/migraphx/par_for.hpp
+7
-77
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-1
src/include/migraphx/simple_par_for.hpp
src/include/migraphx/simple_par_for.hpp
+119
-0
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+8
-8
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+15
-5
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+9
-2
src/onnx/include/migraphx/onnx/pooling.hpp
src/onnx/include/migraphx/onnx/pooling.hpp
+47
-0
src/onnx/onnx.proto
src/onnx/onnx.proto
+193
-61
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+23
-0
src/onnx/parse_lstm.cpp
src/onnx/parse_lstm.cpp
+47
-0
src/onnx/parse_multinomial.cpp
src/onnx/parse_multinomial.cpp
+3
-3
No files found.
src/include/migraphx/op/scatternd_max.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatternd_max
:
scatternd_op
<
scatternd_max
>
{
scatternd_max
()
{}
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
std
::
max
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatternd_min.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatternd_min
:
scatternd_op
<
scatternd_min
>
{
scatternd_min
()
{}
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
std
::
min
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatternd_op.hpp
View file @
f8a75f8a
...
@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
...
@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto
k
=
indices_shape
.
lens
().
back
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
q
=
indices_shape
.
ndim
();
auto
q
=
indices_shape
.
ndim
();
auto
r
=
dyn_out
.
computed_shape
.
ndim
();
auto
r
=
dyn_out
.
computed_shape
.
ndim
();
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
for
(
auto
i
=
0u
;
i
<
updates_shape
.
elements
();
++
i
)
{
auto
updates_idx
=
updates_std
.
multi
(
i
);
auto
updates_idx
=
updates_std
.
multi
(
i
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
copy
(
std
::
copy
(
...
@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
...
@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
self
.
reduction
()(
output
[
dyn_out
.
computed_shape
.
index
(
out_idx
)],
updates
[
i
]);
self
.
reduction
()(
output
[
dyn_out
.
computed_shape
.
index
(
out_idx
)],
updates
[
i
]);
}
);
}
});
});
});
});
...
...
src/include/migraphx/op/slice.hpp
View file @
f8a75f8a
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <array>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unary.hpp
View file @
f8a75f8a
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
...
@@ -84,10 +85,10 @@ struct unary : op_name<Derived>
argument
result
{
dyn_out
.
computed_shape
};
argument
result
{
dyn_out
.
computed_shape
};
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
par_
transform
(
input
.
begin
(),
input
.
end
(),
input
.
end
(),
output
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
});
});
});
});
return
result
;
return
result
;
...
...
src/include/migraphx/op/unique.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility>
#include <map>
#include <limits>
#include <optional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
// https://onnx.ai/onnx/operators/onnx__Unique.html
// The Onnx spec refers to numpy specification, used as a reference:
// https://numpy.org/doc/stable/reference/generated/numpy.unique.html
// Input : Given an array of elements : X.
// Output(s) :
// 1. Find the unique elements (Y) of input (X).
//
// There are three outputs in addition to the unique elements in Y:
// 2. the indices of the input array that give the unique values
// 3. the indices of the unique array that reconstruct the input array
// 4. the number of times each unique value comes up in the input array
// Optional Attribute: 'Sorted' = 1 for sorted; = 0 for unsorted.
// Onnx specification makes 'sorted' a default, while Numpy always sorts.
//
// Optional Attribute: 'Axis' is 'None' (default) or a valid int < rank(X).
// Negative values are allowed.
//
// Numpy has the following important note on Axis:
// ------------------------------------------------------------------
// When an axis is specified the subarrays indexed by the axis are
// sorted. This is done by making the specified axis the first
// dimension of the array (move the axis to the first dimension to
// keep the order of the other axes) and then flattening the subarrays
// in C order. The flattened subarrays are then viewed as a structured
// type with each element given a label, with the effect that we end
// up with a 1-D array of structured types that can be treated in the
// same way as any other 1-D array. The result is that the flattened
// subarrays are sorted in lexicographic order starting with the first
// element.
// ------------------------------------------------------------------
struct
unique
{
template
<
class
T
>
auto
make_idx_less_fn
(
const
T
&
data
,
size_t
chunk_sz
)
const
{
return
[
&
data
,
chunk_sz
](
auto
idx1
,
auto
idx2
)
{
return
std
::
lexicographical_compare
(
data
.
begin
()
+
idx1
,
data
.
begin
()
+
idx1
+
chunk_sz
,
data
.
begin
()
+
idx2
,
data
.
begin
()
+
idx2
+
chunk_sz
);
};
}
// CASE SORTED:
//
// To process into a sorted unique series of elements/chunks:
// Chunk size == 1 means a simple element; >1 means a flat representation.
// Steps: first go through the input elements/chunks for uniqueness.
// At the end of this processing, per the sorted sequence of unique elements:
// update/create data structures: y, y_indices, x_rev_indices, y_count
//
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// OUTPUT(s): indices..
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [1, 2, 3, 4] --- the unique output is constructed from x[y_indices[...]]
template
<
class
T
>
auto
sorted_uniq_indices
(
const
T
&
input_data
,
size_t
chunk_sz
)
const
{
struct
y_info
{
size_t
y_idx
;
size_t
x_idx
;
size_t
ct
=
0
;
};
auto
idx_less_fn
=
make_idx_less_fn
(
input_data
,
chunk_sz
);
std
::
map
<
size_t
,
y_info
,
decltype
(
idx_less_fn
)
>
uniq_val_map
(
idx_less_fn
);
std
::
tuple
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>>
rv
;
auto
&
[
y_indices
,
x_rev_indices
,
y_count
]
=
rv
;
// go through all the elements and find the unique elements..
size_t
count_x
=
input_data
.
size
();
for
(
size_t
f_idx
=
0
,
x_idx
=
0
;
f_idx
<
count_x
;
f_idx
+=
chunk_sz
,
x_idx
++
)
{
y_info
entry
=
{.
y_idx
=
uniq_val_map
.
size
(),
.
x_idx
=
x_idx
};
auto
[
itr
,
added_new
]
=
uniq_val_map
.
insert
({
f_idx
,
entry
});
itr
->
second
.
ct
++
;
x_rev_indices
.
push_back
(
itr
->
second
.
y_idx
);
}
std
::
vector
<
std
::
size_t
>
y2x_indices
(
uniq_val_map
.
size
());
y_indices
.
resize
(
uniq_val_map
.
size
());
y_count
.
resize
(
uniq_val_map
.
size
());
size_t
idx
=
0
;
// the unique elements are now sorted:
// post-processing for all the return indices.
for
(
const
auto
&
v
:
uniq_val_map
)
{
y2x_indices
[
v
.
second
.
y_idx
]
=
idx
;
y_indices
[
idx
]
=
v
.
second
.
x_idx
;
y_count
[
idx
]
=
v
.
second
.
ct
;
idx
++
;
}
// update x_rev_indices as per the sorted order of y_indices
for
(
auto
&
i
:
x_rev_indices
)
i
=
y2x_indices
[
i
];
return
rv
;
}
// CASE UNSORTED:
//
// To process into an un-sorted unique series of elements/chunks:
// For chunk size = 1 is a simple element, else use a flat representation of a tensor obj
// Go through the input elements/chunks one by one with inline processing of indices..
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// OUTPUT(s): indices..
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [2, 1, 3, 4] --- the unique output is constructed from x[y_indices[...]]
// Output data structures: y_indices, x_rev_indices, y_count are processed inline.
template
<
class
T
>
auto
unsorted_uniq_indices
(
const
T
&
input_data
,
size_t
chunk_sz
)
const
{
auto
idx_less_fn
=
make_idx_less_fn
(
input_data
,
chunk_sz
);
std
::
map
<
size_t
,
size_t
,
decltype
(
idx_less_fn
)
>
uniq_val_map
(
idx_less_fn
);
// rv is used for NVRO below..
std
::
tuple
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>>
rv
;
auto
&
[
y_indices
,
x_rev_indices
,
y_count
]
=
rv
;
// go through all the elements and add the unique elements into the map..
// inline processing for outputs: y_indices, x_rev_indices, y_count
size_t
count_x
=
input_data
.
size
();
for
(
size_t
f_idx
=
0
;
f_idx
<
count_x
;
f_idx
+=
chunk_sz
)
{
auto
[
itr
,
added_new
]
=
uniq_val_map
.
insert
({
f_idx
,
y_indices
.
size
()});
if
(
added_new
)
{
y_count
.
push_back
(
0
);
y_indices
.
push_back
(
x_rev_indices
.
size
());
}
y_count
[
itr
->
second
]
++
;
x_rev_indices
.
push_back
(
itr
->
second
);
}
return
rv
;
}
// Axis. Default: none. Range: [-rank, rank-1]
std
::
optional
<
int64_t
>
axis
;
// Sorted, Default: 1= sorted. 0 = unsorted.
bool
sorted
=
true
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
sorted
,
"sorted"
));
}
std
::
string
name
()
const
{
return
"unique"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&
sh_x
=
inputs
[
0
];
auto
lens_x
=
sh_x
.
lens
();
size_t
dim_x
=
sh_x
.
ndim
();
size_t
max_uniq_ct
=
sh_x
.
elements
();
std
::
vector
<
shape
::
dynamic_dimension
>
d_out
;
if
(
axis
)
{
int64_t
t_axis
=
migraphx
::
tune_axis
(
dim_x
,
*
axis
,
name
());
if
(
t_axis
!=
0
)
MIGRAPHX_THROW
(
"Unique: Only supports axis = 0 or None"
);
d_out
=
sh_x
.
to_dynamic
().
dyn_dims
();
// only axis = 0 is supported:
max_uniq_ct
=
lens_x
[
0
];
// min = 1 unique element; max = full dimension along axis 0
d_out
[
0
]
=
{
1
,
max_uniq_ct
};
}
else
{
d_out
.
push_back
({
1
,
max_uniq_ct
});
}
shape
sh_y
=
{
sh_x
.
type
(),
d_out
};
// The three outputted Indices are just 1-D:
shape
sh_idx
{
shape
::
int64_type
,
{
d_out
[
0
]}};
return
{{
sh_y
,
sh_idx
,
sh_idx
,
sh_idx
}};
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
auto
sh_x
=
args
.
front
().
get_shape
();
auto
lens_x
=
sh_x
.
lens
();
shape
output_shape
=
dyn_out
.
computed_shape
;
auto
vec_ss
=
output_shape
.
sub_shapes
();
auto
ct_x
=
sh_x
.
elements
();
shape
sh_y
=
{
vec_ss
[
0
].
type
(),
{
ct_x
}};
shape
sh_idx
=
{
vec_ss
[
1
].
type
(),
{
ct_x
}};
shape
sh_x_idx
=
{
vec_ss
[
1
].
type
(),
{
ct_x
}};
argument
res_y
{
sh_y
};
argument
res_y_idx
{
sh_idx
};
argument
res_x_rev_idx
{
sh_idx
};
argument
res_y_ct_idx
{
sh_idx
};
std
::
vector
<
size_t
>
out_y_idx
;
std
::
vector
<
size_t
>
out_x_rev_idx
;
std
::
vector
<
size_t
>
out_y_ct
;
// If axis is not none, for >1D tensors, we have to consider
// then, the uniqueness of chunks of sub-tensors: a subsequence of built-ins..
// For a built-in type, chunk_sz is of course = 1
size_t
chunk_sz
=
1
;
if
(
axis
)
chunk_sz
=
ct_x
/
lens_x
[
0
];
// axis = 0 is supported.
visit_all
(
args
.
front
(),
res_y
)([
&
](
auto
x
,
auto
y_flat
)
{
using
o_type
=
typename
decltype
(
x
)
::
value_type
;
std
::
vector
<
o_type
>
x_in
(
x
.
begin
(),
x
.
end
());
std
::
tie
(
out_y_idx
,
out_x_rev_idx
,
out_y_ct
)
=
sorted
?
sorted_uniq_indices
(
x_in
,
chunk_sz
)
:
unsorted_uniq_indices
(
x_in
,
chunk_sz
);
const
auto
uniq_ct
=
out_y_idx
.
size
();
// construct y from x[indices] in flattened form
// later we reshape y to the final shape..
auto
y_dst
=
y_flat
.
begin
();
for
(
size_t
idx
=
0
;
idx
<
uniq_ct
;
idx
++
)
y_dst
=
copy_n
(
x_in
.
begin
()
+
out_y_idx
[
idx
]
*
chunk_sz
,
chunk_sz
,
y_dst
);
std
::
vector
<
size_t
>
lens_y
;
// if axis is specified:
// the output shape keeps the n-1 dimensions of x
if
(
axis
)
{
lens_y
=
lens_x
;
lens_y
[
0
]
=
uniq_ct
;
}
else
{
lens_y
=
{
uniq_ct
};
}
sh_y
=
{
sh_y
.
type
(),
lens_y
};
sh_idx
=
{
sh_idx
.
type
(),
{
uniq_ct
}};
});
visit_all
(
res_y_idx
,
res_x_rev_idx
,
res_y_ct_idx
)(
[
&
](
auto
y_indices
,
auto
x_rev_indices
,
auto
y_count
)
{
std
::
copy
(
out_y_idx
.
begin
(),
out_y_idx
.
end
(),
y_indices
.
begin
());
std
::
copy
(
out_x_rev_idx
.
begin
(),
out_x_rev_idx
.
end
(),
x_rev_indices
.
begin
());
std
::
copy
(
out_y_ct
.
begin
(),
out_y_ct
.
end
(),
y_count
.
begin
());
sh_x_idx
=
{
sh_idx
.
type
(),
{
out_x_rev_idx
.
size
()}};
});
return
{{
res_y
.
reshape
(
sh_y
),
res_y_idx
.
reshape
(
sh_idx
),
res_x_rev_idx
.
reshape
(
sh_x_idx
),
res_y_ct_idx
.
reshape
(
sh_idx
)}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
f8a75f8a
...
@@ -119,6 +119,8 @@
...
@@ -119,6 +119,8 @@
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_max.hpp>
#include <migraphx/op/scatternd_min.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
#include <migraphx/op/sinh.hpp>
...
@@ -137,6 +139,7 @@
...
@@ -137,6 +139,7 @@
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unique.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#include <migraphx/op/where.hpp>
...
...
src/include/migraphx/par.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
struct
exception_list
{
std
::
vector
<
std
::
exception_ptr
>
exceptions
;
std
::
mutex
m
;
void
add_exception
()
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
m
);
exceptions
.
push_back
(
std
::
current_exception
());
}
template
<
class
F
>
auto
collect
(
F
f
)
{
return
[
f
,
this
](
auto
&&
...
xs
)
{
try
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
}
catch
(...)
{
this
->
add_exception
();
}
};
}
void
throw_if_exception
()
const
{
if
(
not
exceptions
.
empty
())
std
::
rethrow_exception
(
exceptions
.
front
());
}
};
}
// namespace detail
template
<
class
InputIt
,
class
OutputIt
,
class
UnaryOperation
>
OutputIt
par_transform
(
InputIt
first1
,
InputIt
last1
,
OutputIt
d_first
,
UnaryOperation
unary_op
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
transform
(
std
::
execution
::
par
,
first1
,
last1
,
d_first
,
std
::
move
(
unary_op
));
#else
simple_par_for
(
last1
-
first1
,
[
&
](
auto
i
)
{
d_first
[
i
]
=
unary_op
(
first1
[
i
]);
});
return
d_first
+
(
last1
-
first1
);
#endif
}
template
<
class
InputIt1
,
class
InputIt2
,
class
OutputIt
,
class
BinaryOperation
>
OutputIt
par_transform
(
InputIt1
first1
,
InputIt1
last1
,
InputIt2
first2
,
OutputIt
d_first
,
BinaryOperation
binary_op
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
transform
(
std
::
execution
::
par
,
first1
,
last1
,
first2
,
d_first
,
std
::
move
(
binary_op
));
#else
simple_par_for
(
last1
-
first1
,
[
&
](
auto
i
)
{
d_first
[
i
]
=
binary_op
(
first1
[
i
],
first2
[
i
]);
});
return
d_first
+
(
last1
-
first1
);
#endif
}
template
<
class
InputIt
,
class
UnaryFunction
>
void
par_for_each
(
InputIt
first
,
InputIt
last
,
UnaryFunction
f
)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail
::
exception_list
ex
;
std
::
for_each
(
std
::
execution
::
par
,
first
,
last
,
ex
.
collect
(
std
::
move
(
f
)));
ex
.
throw_if_exception
();
#else
simple_par_for
(
last
-
first
,
[
&
](
auto
i
)
{
f
(
first
[
i
]);
});
#endif
}
template
<
class
...
Ts
>
auto
par_copy_if
(
Ts
&&
...
xs
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
copy_if
(
std
::
execution
::
par
,
std
::
forward
<
Ts
>
(
xs
)...);
#else
return
std
::
copy_if
(
std
::
forward
<
Ts
>
(
xs
)...);
#endif
}
template
<
class
...
Ts
>
auto
par_sort
(
Ts
&&
...
xs
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
sort
(
std
::
execution
::
par
,
std
::
forward
<
Ts
>
(
xs
)...);
#else
return
std
::
sort
(
std
::
forward
<
Ts
>
(
xs
)...);
#endif
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
src/include/migraphx/par_for.hpp
View file @
f8a75f8a
...
@@ -24,93 +24,23 @@
...
@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <migraphx/par.hpp>
#include <cmath>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
tid
,
F
f
)
->
decltype
(
f
(
i
,
tid
))
{
f
(
i
,
tid
);
}
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
,
F
f
)
->
decltype
(
f
(
i
))
{
f
(
i
);
}
template
<
class
F
>
void
par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
thread_invoke
(
i
,
0
,
f
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
size_t
tid
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
,
&
tid
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
thread_invoke
(
i
,
tid
,
f
);
}
});
work
+=
grainsize
;
++
tid
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
void
par_for
(
std
::
size_t
n
,
F
f
)
{
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
using
iterator
=
basic_iota_iterator
<
id
,
std
::
size_t
>
;
n
/
std
::
max
<
std
::
size_t
>
(
1
,
min_grain
));
par_for_each
(
iterator
{
0
,
{}},
iterator
{
n
,
{}},
f
);
par_for_impl
(
n
,
threadsize
,
f
);
}
}
template
<
class
F
>
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
F
f
)
void
par_for
(
std
::
size_t
n
,
std
::
size_t
,
F
f
)
{
{
const
int
min_grain
=
8
;
par_for
(
n
,
f
);
par_for
(
n
,
min_grain
,
f
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
f8a75f8a
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <string>
#include <string>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/shape.hpp
View file @
f8a75f8a
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
...
@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
...
...
src/include/migraphx/simple_par_for.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
tid
,
F
f
)
->
decltype
(
f
(
i
,
tid
))
{
f
(
i
,
tid
);
}
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
,
F
f
)
->
decltype
(
f
(
i
))
{
f
(
i
);
}
template
<
class
F
>
void
simple_par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
thread_invoke
(
i
,
0
,
f
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
size_t
tid
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
,
&
tid
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
thread_invoke
(
i
,
tid
,
f
);
}
});
work
+=
grainsize
;
++
tid
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
void
simple_par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
std
::
max
<
std
::
size_t
>
(
1
,
min_grain
));
simple_par_for_impl
(
n
,
threadsize
,
f
);
}
template
<
class
F
>
void
simple_par_for
(
std
::
size_t
n
,
F
f
)
{
const
int
min_grain
=
8
;
simple_par_for
(
n
,
min_grain
,
f
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tune_axis.hpp
View file @
f8a75f8a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -24,21 +24,21 @@
...
@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
inline
int
tune_axis
(
int
n_dim
,
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
{
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
<
0
)
{
axis
+=
n_dim
;
if
(
axis
<
0
or
axis
>=
n_dim
)
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
return
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
return
axis
;
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/type_traits.hpp
View file @
f8a75f8a
...
@@ -28,25 +28,35 @@
...
@@ -28,25 +28,35 @@
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \
template <class X> \
struct trait : std::trait<X> \
{ \
};
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
template <> \
struct trait<T> : std::true_type \
struct trait<T> : std::true_type \
{ \
{ \
};
};
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_floating_point
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
std
::
conditional_t
<
is_floating_point
<
T
>
{},
std
::
conditional_t
<
is_floating_point
<
T
>
{},
...
...
src/onnx/CMakeLists.txt
View file @
f8a75f8a
...
@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
...
@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
onnx-proto PRIVATE -w
)
if
(
MSVC
)
target_compile_options
(
onnx-proto PRIVATE /w
)
else
()
target_compile_options
(
onnx-proto PRIVATE -w
)
endif
()
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
set_target_properties
(
onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
...
@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
...
@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header
(
migraphx_onnx
)
migraphx_generate_export_header
(
migraphx_onnx
)
rocm_set_soversion
(
migraphx_onnx
${
MIGRAPHX_SO_VERSION
}
)
rocm_set_soversion
(
migraphx_onnx
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_onnx
)
rocm_clang_tidy_check
(
migraphx_onnx
)
target_link_libraries
(
migraphx_onnx PRIVATE onnx-proto
"-Wl,--exclude-libs,ALL"
)
target_link_libraries
(
migraphx_onnx PRIVATE onnx-proto
)
if
(
NOT WIN32
)
target_link_libraries
(
migraphx_onnx PRIVATE
"-Wl,--exclude-libs,ALL"
)
endif
()
target_link_libraries
(
migraphx_onnx PUBLIC migraphx
)
target_link_libraries
(
migraphx_onnx PUBLIC migraphx
)
rocm_install_targets
(
rocm_install_targets
(
...
...
src/onnx/include/migraphx/onnx/pooling.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
value
handle_pooling_values
(
const
op_desc
&
opd
,
onnx_parser
::
node_info
info
,
const
shape
&
in_shape
,
value
values
);
instruction_ref
add_pooling_op
(
const
op_desc
&
opd
,
onnx_parser
::
node_info
info
,
instruction_ref
l0
);
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/onnx/onnx.proto
View file @
f8a75f8a
This diff is collapsed.
Click to expand it.
src/onnx/onnx_parser.cpp
View file @
f8a75f8a
...
@@ -34,7 +34,9 @@
...
@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
SPARSE_TENSOR
:
case
onnx
::
AttributeProto
::
SPARSE_TENSOR
:
case
onnx
::
AttributeProto
::
SPARSE_TENSORS
:
case
onnx
::
AttributeProto
::
SPARSE_TENSORS
:
case
onnx
::
AttributeProto
::
TYPE_PROTOS
:
case
onnx
::
AttributeProto
::
TYPE_PROTO
:
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
MIGRAPHX_THROW
(
"PARSE_VALUE: Invalid attribute type "
+
std
::
to_string
(
attr
.
type
()));
MIGRAPHX_THROW
(
"PARSE_VALUE: Invalid attribute type "
+
std
::
to_string
(
attr
.
type
()));
...
@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT8E4M3FNUZ
:
{
std
::
vector
<
int32_t
>
data_int32
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
data_fp8
;
std
::
transform
(
data_int32
.
begin
(),
data_int32
.
end
(),
std
::
back_inserter
(
data_fp8
),
[](
float
raw_val
)
{
return
migraphx
::
fp8
::
fp8e4m3fnuz
{
raw_val
};
});
return
create_literal
(
shape
::
fp8e4m3fnuz_type
,
dims
,
data_fp8
);
}
case
onnx
::
TensorProto
::
FLOAT8E5M2FNUZ
:
case
onnx
::
TensorProto
::
FLOAT8E5M2
:
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
...
@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
...
@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case
11
:
return
shape
::
double_type
;
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
case
13
:
return
shape
::
uint64_type
;
case
18
:
return
shape
::
fp8e4m3fnuz_type
;
case
14
:
case
15
:
case
16
:
case
17
:
case
19
:
case
20
:
default:
{
default:
{
MIGRAPHX_THROW
(
"Prototensor data type "
+
std
::
to_string
(
dtype
)
+
" not supported"
);
MIGRAPHX_THROW
(
"Prototensor data type "
+
std
::
to_string
(
dtype
)
+
" not supported"
);
}
}
...
...
src/onnx/parse_lstm.cpp
View file @
f8a75f8a
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
}
}
void
lstm_transpose_inputs
(
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
0
]);
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_undefined
())
{
args
[
5
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
5
]);
}
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_undefined
())
{
args
[
6
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
6
]);
}
}
void
lstm_transpose_outputs
(
onnx_parser
::
node_info
&
info
,
instruction_ref
&
hidden_states
,
instruction_ref
&
last_output
,
instruction_ref
&
last_cell_output
)
{
std
::
vector
<
int64_t
>
perm_hs
{
2
,
0
,
1
,
3
};
hidden_states
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hs
}}),
hidden_states
);
std
::
vector
<
int64_t
>
perm_last
{
1
,
0
,
2
};
last_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_output
);
last_cell_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_cell_output
);
}
struct
parse_lstm
:
op_parser
<
parse_lstm
>
struct
parse_lstm
:
op_parser
<
parse_lstm
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
}
int
layout
=
0
;
if
(
contains
(
info
.
attributes
,
"layout"
))
{
layout
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"layout"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
if
(
args
.
size
()
<
8
)
{
{
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
}
if
(
layout
!=
0
)
{
lstm_transpose_inputs
(
info
,
args
);
}
// first output for concatenation of hidden states
// first output for concatenation of hidden states
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto
last_cell_output
=
auto
last_cell_output
=
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
if
(
layout
!=
0
)
{
lstm_transpose_outputs
(
info
,
hidden_states
,
last_output
,
last_cell_output
);
}
return
{
hidden_states
,
last_output
,
last_cell_output
};
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
}
};
};
...
...
src/onnx/parse_multinomial.cpp
View file @
f8a75f8a
...
@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
...
@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its
// use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size .
// number of elements is batch_size * sample_size .
size_t
batch_size
=
s0
.
lens
().
front
();
size_t
batch_size
=
s0
.
lens
().
front
();
auto
rand_dummy
=
info
.
add_literal
(
auto
rand_dummy
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
batch_size
*
sample_size
}}
);
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
sample_size
}}
,
std
::
vector
<
float
>
(
batch_size
*
sample_size
)});
randoms
=
randoms
=
info
.
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
rand_dummy
);
info
.
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
rand_dummy
);
}
}
...
...
Prev
1
2
3
4
5
6
7
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment