Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f8a75f8a
"vscode:/vscode.git/clone" did not exist on "25dc3235224e20040d4fb1369a60e9fe12cb7b3e"
Commit
f8a75f8a
authored
Dec 07, 2023
by
Paul
Browse files
Merge
parents
74448ed6
d00fdf6e
Changes
242
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1037 additions
and
163 deletions
+1037
-163
src/include/migraphx/op/scatternd_max.hpp
src/include/migraphx/op/scatternd_max.hpp
+47
-0
src/include/migraphx/op/scatternd_min.hpp
src/include/migraphx/op/scatternd_min.hpp
+47
-0
src/include/migraphx/op/scatternd_op.hpp
src/include/migraphx/op/scatternd_op.hpp
+3
-2
src/include/migraphx/op/slice.hpp
src/include/migraphx/op/slice.hpp
+1
-0
src/include/migraphx/op/unary.hpp
src/include/migraphx/op/unary.hpp
+5
-4
src/include/migraphx/op/unique.hpp
src/include/migraphx/op/unique.hpp
+323
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+3
-0
src/include/migraphx/par.hpp
src/include/migraphx/par.hpp
+133
-0
src/include/migraphx/par_for.hpp
src/include/migraphx/par_for.hpp
+7
-77
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+3
-1
src/include/migraphx/simple_par_for.hpp
src/include/migraphx/simple_par_for.hpp
+119
-0
src/include/migraphx/tune_axis.hpp
src/include/migraphx/tune_axis.hpp
+8
-8
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+15
-5
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+9
-2
src/onnx/include/migraphx/onnx/pooling.hpp
src/onnx/include/migraphx/onnx/pooling.hpp
+47
-0
src/onnx/onnx.proto
src/onnx/onnx.proto
+193
-61
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+23
-0
src/onnx/parse_lstm.cpp
src/onnx/parse_lstm.cpp
+47
-0
src/onnx/parse_multinomial.cpp
src/onnx/parse_multinomial.cpp
+3
-3
No files found.
src/include/migraphx/op/scatternd_max.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MAX_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatternd_max
:
scatternd_op
<
scatternd_max
>
{
scatternd_max
()
{}
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
std
::
max
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatternd_min.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MIN_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatternd_min
:
scatternd_op
<
scatternd_min
>
{
scatternd_min
()
{}
auto
reduction
()
const
{
return
[](
auto
&
x
,
const
auto
&
y
)
{
x
=
std
::
min
(
x
,
y
);
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/op/scatternd_op.hpp
View file @
f8a75f8a
...
@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
...
@@ -121,7 +121,8 @@ struct scatternd_op : op_name<Derived>
auto
k
=
indices_shape
.
lens
().
back
();
auto
k
=
indices_shape
.
lens
().
back
();
auto
q
=
indices_shape
.
ndim
();
auto
q
=
indices_shape
.
ndim
();
auto
r
=
dyn_out
.
computed_shape
.
ndim
();
auto
r
=
dyn_out
.
computed_shape
.
ndim
();
par_for
(
updates_shape
.
elements
(),
[
&
](
const
auto
i
)
{
for
(
auto
i
=
0u
;
i
<
updates_shape
.
elements
();
++
i
)
{
auto
updates_idx
=
updates_std
.
multi
(
i
);
auto
updates_idx
=
updates_std
.
multi
(
i
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
vector
<
std
::
size_t
>
indices_idx
(
q
,
0
);
std
::
copy
(
std
::
copy
(
...
@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
...
@@ -135,7 +136,7 @@ struct scatternd_op : op_name<Derived>
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
std
::
copy
(
updates_idx
.
begin
()
+
q
-
1
,
updates_idx
.
end
(),
out_idx
.
begin
()
+
k
);
self
.
reduction
()(
output
[
dyn_out
.
computed_shape
.
index
(
out_idx
)],
updates
[
i
]);
self
.
reduction
()(
output
[
dyn_out
.
computed_shape
.
index
(
out_idx
)],
updates
[
i
]);
}
);
}
});
});
});
});
...
...
src/include/migraphx/op/slice.hpp
View file @
f8a75f8a
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <array>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/unary.hpp
View file @
f8a75f8a
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -84,7 +85,7 @@ struct unary : op_name<Derived>
...
@@ -84,7 +85,7 @@ struct unary : op_name<Derived>
argument
result
{
dyn_out
.
computed_shape
};
argument
result
{
dyn_out
.
computed_shape
};
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
([
&
](
auto
output
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
args
[
0
].
visit
([
&
](
auto
input
)
{
std
::
transform
(
input
.
begin
(),
par_
transform
(
input
.
begin
(),
input
.
end
(),
input
.
end
(),
output
.
begin
(),
output
.
begin
(),
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
static_cast
<
const
Derived
&>
(
*
this
).
apply
());
...
...
src/include/migraphx/op/unique.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNIQUE_HPP
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility>
#include <map>
#include <limits>
#include <optional>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
// https://onnx.ai/onnx/operators/onnx__Unique.html
// The Onnx spec refers to numpy specification, used as a reference:
// https://numpy.org/doc/stable/reference/generated/numpy.unique.html
// Input : Given an array of elements : X.
// Output(s) :
// 1. Find the unique elements (Y) of input (X).
//
// There are three outputs in addition to the unique elements in Y:
// 2. the indices of the input array that give the unique values
// 3. the indices of the unique array that reconstruct the input array
// 4. the number of times each unique value comes up in the input array
// Optional Attribute: 'Sorted' = 1 for sorted; = 0 for unsorted.
// Onnx specification makes 'sorted' a default, while Numpy always sorts.
//
// Optional Attribute: 'Axis' is 'None' (default) or a valid int < rank(X).
// Negative values are allowed.
//
// Numpy has the following important note on Axis:
// ------------------------------------------------------------------
// When an axis is specified the subarrays indexed by the axis are
// sorted. This is done by making the specified axis the first
// dimension of the array (move the axis to the first dimension to
// keep the order of the other axes) and then flattening the subarrays
// in C order. The flattened subarrays are then viewed as a structured
// type with each element given a label, with the effect that we end
// up with a 1-D array of structured types that can be treated in the
// same way as any other 1-D array. The result is that the flattened
// subarrays are sorted in lexicographic order starting with the first
// element.
// ------------------------------------------------------------------
struct
unique
{
template
<
class
T
>
auto
make_idx_less_fn
(
const
T
&
data
,
size_t
chunk_sz
)
const
{
return
[
&
data
,
chunk_sz
](
auto
idx1
,
auto
idx2
)
{
return
std
::
lexicographical_compare
(
data
.
begin
()
+
idx1
,
data
.
begin
()
+
idx1
+
chunk_sz
,
data
.
begin
()
+
idx2
,
data
.
begin
()
+
idx2
+
chunk_sz
);
};
}
// CASE SORTED:
//
// To process into a sorted unique series of elements/chunks:
// Chunk size == 1 means a simple element; >1 means a flat representation.
// Steps: first go through the input elements/chunks for uniqueness.
// At the end of this processing, per the sorted sequence of unique elements:
// update/create data structures: y, y_indices, x_rev_indices, y_count
//
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 1;
// OUTPUT(s): indices..
// y_indices: [1, 0, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [1, 0, 0, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [2, 1, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [1, 2, 3, 4] --- the unique output is constructed from x[y_indices[...]]
template
<
class
T
>
auto
sorted_uniq_indices
(
const
T
&
input_data
,
size_t
chunk_sz
)
const
{
struct
y_info
{
size_t
y_idx
;
size_t
x_idx
;
size_t
ct
=
0
;
};
auto
idx_less_fn
=
make_idx_less_fn
(
input_data
,
chunk_sz
);
std
::
map
<
size_t
,
y_info
,
decltype
(
idx_less_fn
)
>
uniq_val_map
(
idx_less_fn
);
std
::
tuple
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>>
rv
;
auto
&
[
y_indices
,
x_rev_indices
,
y_count
]
=
rv
;
// go through all the elements and find the unique elements..
size_t
count_x
=
input_data
.
size
();
for
(
size_t
f_idx
=
0
,
x_idx
=
0
;
f_idx
<
count_x
;
f_idx
+=
chunk_sz
,
x_idx
++
)
{
y_info
entry
=
{.
y_idx
=
uniq_val_map
.
size
(),
.
x_idx
=
x_idx
};
auto
[
itr
,
added_new
]
=
uniq_val_map
.
insert
({
f_idx
,
entry
});
itr
->
second
.
ct
++
;
x_rev_indices
.
push_back
(
itr
->
second
.
y_idx
);
}
std
::
vector
<
std
::
size_t
>
y2x_indices
(
uniq_val_map
.
size
());
y_indices
.
resize
(
uniq_val_map
.
size
());
y_count
.
resize
(
uniq_val_map
.
size
());
size_t
idx
=
0
;
// the unique elements are now sorted:
// post-processing for all the return indices.
for
(
const
auto
&
v
:
uniq_val_map
)
{
y2x_indices
[
v
.
second
.
y_idx
]
=
idx
;
y_indices
[
idx
]
=
v
.
second
.
x_idx
;
y_count
[
idx
]
=
v
.
second
.
ct
;
idx
++
;
}
// update x_rev_indices as per the sorted order of y_indices
for
(
auto
&
i
:
x_rev_indices
)
i
=
y2x_indices
[
i
];
return
rv
;
}
// CASE UNSORTED:
//
// To process into an un-sorted unique series of elements/chunks:
// For chunk size = 1 is a simple element, else use a flat representation of a tensor obj
// Go through the input elements/chunks one by one with inline processing of indices..
// INPUT x: [2, 1, 1, 3, 4, 3], attr_sorted = 0;
// OUTPUT(s): indices..
// y_indices: [0, 1, 3, 4] --- first incidence, in terms of index in sequence x
// x_rev_indices: [0, 1, 1, 2, 3, 2] --- x seen in terms of indices of unique sequence y
// y_count: [1, 2, 2, 1] -- count at each y_index. sum = len(x)
// NOTE: y [2, 1, 3, 4] --- the unique output is constructed from x[y_indices[...]]
// Output data structures: y_indices, x_rev_indices, y_count are processed inline.
template
<
class
T
>
auto
unsorted_uniq_indices
(
const
T
&
input_data
,
size_t
chunk_sz
)
const
{
auto
idx_less_fn
=
make_idx_less_fn
(
input_data
,
chunk_sz
);
std
::
map
<
size_t
,
size_t
,
decltype
(
idx_less_fn
)
>
uniq_val_map
(
idx_less_fn
);
// rv is used for NVRO below..
std
::
tuple
<
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>
,
std
::
vector
<
std
::
size_t
>>
rv
;
auto
&
[
y_indices
,
x_rev_indices
,
y_count
]
=
rv
;
// go through all the elements and add the unique elements into the map..
// inline processing for outputs: y_indices, x_rev_indices, y_count
size_t
count_x
=
input_data
.
size
();
for
(
size_t
f_idx
=
0
;
f_idx
<
count_x
;
f_idx
+=
chunk_sz
)
{
auto
[
itr
,
added_new
]
=
uniq_val_map
.
insert
({
f_idx
,
y_indices
.
size
()});
if
(
added_new
)
{
y_count
.
push_back
(
0
);
y_indices
.
push_back
(
x_rev_indices
.
size
());
}
y_count
[
itr
->
second
]
++
;
x_rev_indices
.
push_back
(
itr
->
second
);
}
return
rv
;
}
// Axis. Default: none. Range: [-rank, rank-1]
std
::
optional
<
int64_t
>
axis
;
// Sorted, Default: 1= sorted. 0 = unsorted.
bool
sorted
=
true
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
),
f
(
self
.
sorted
,
"sorted"
));
}
std
::
string
name
()
const
{
return
"unique"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&
sh_x
=
inputs
[
0
];
auto
lens_x
=
sh_x
.
lens
();
size_t
dim_x
=
sh_x
.
ndim
();
size_t
max_uniq_ct
=
sh_x
.
elements
();
std
::
vector
<
shape
::
dynamic_dimension
>
d_out
;
if
(
axis
)
{
int64_t
t_axis
=
migraphx
::
tune_axis
(
dim_x
,
*
axis
,
name
());
if
(
t_axis
!=
0
)
MIGRAPHX_THROW
(
"Unique: Only supports axis = 0 or None"
);
d_out
=
sh_x
.
to_dynamic
().
dyn_dims
();
// only axis = 0 is supported:
max_uniq_ct
=
lens_x
[
0
];
// min = 1 unique element; max = full dimension along axis 0
d_out
[
0
]
=
{
1
,
max_uniq_ct
};
}
else
{
d_out
.
push_back
({
1
,
max_uniq_ct
});
}
shape
sh_y
=
{
sh_x
.
type
(),
d_out
};
// The three outputted Indices are just 1-D:
shape
sh_idx
{
shape
::
int64_type
,
{
d_out
[
0
]}};
return
{{
sh_y
,
sh_idx
,
sh_idx
,
sh_idx
}};
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
auto
sh_x
=
args
.
front
().
get_shape
();
auto
lens_x
=
sh_x
.
lens
();
shape
output_shape
=
dyn_out
.
computed_shape
;
auto
vec_ss
=
output_shape
.
sub_shapes
();
auto
ct_x
=
sh_x
.
elements
();
shape
sh_y
=
{
vec_ss
[
0
].
type
(),
{
ct_x
}};
shape
sh_idx
=
{
vec_ss
[
1
].
type
(),
{
ct_x
}};
shape
sh_x_idx
=
{
vec_ss
[
1
].
type
(),
{
ct_x
}};
argument
res_y
{
sh_y
};
argument
res_y_idx
{
sh_idx
};
argument
res_x_rev_idx
{
sh_idx
};
argument
res_y_ct_idx
{
sh_idx
};
std
::
vector
<
size_t
>
out_y_idx
;
std
::
vector
<
size_t
>
out_x_rev_idx
;
std
::
vector
<
size_t
>
out_y_ct
;
// If axis is not none, for >1D tensors, we have to consider
// then, the uniqueness of chunks of sub-tensors: a subsequence of built-ins..
// For a built-in type, chunk_sz is of course = 1
size_t
chunk_sz
=
1
;
if
(
axis
)
chunk_sz
=
ct_x
/
lens_x
[
0
];
// axis = 0 is supported.
visit_all
(
args
.
front
(),
res_y
)([
&
](
auto
x
,
auto
y_flat
)
{
using
o_type
=
typename
decltype
(
x
)
::
value_type
;
std
::
vector
<
o_type
>
x_in
(
x
.
begin
(),
x
.
end
());
std
::
tie
(
out_y_idx
,
out_x_rev_idx
,
out_y_ct
)
=
sorted
?
sorted_uniq_indices
(
x_in
,
chunk_sz
)
:
unsorted_uniq_indices
(
x_in
,
chunk_sz
);
const
auto
uniq_ct
=
out_y_idx
.
size
();
// construct y from x[indices] in flattened form
// later we reshape y to the final shape..
auto
y_dst
=
y_flat
.
begin
();
for
(
size_t
idx
=
0
;
idx
<
uniq_ct
;
idx
++
)
y_dst
=
copy_n
(
x_in
.
begin
()
+
out_y_idx
[
idx
]
*
chunk_sz
,
chunk_sz
,
y_dst
);
std
::
vector
<
size_t
>
lens_y
;
// if axis is specified:
// the output shape keeps the n-1 dimensions of x
if
(
axis
)
{
lens_y
=
lens_x
;
lens_y
[
0
]
=
uniq_ct
;
}
else
{
lens_y
=
{
uniq_ct
};
}
sh_y
=
{
sh_y
.
type
(),
lens_y
};
sh_idx
=
{
sh_idx
.
type
(),
{
uniq_ct
}};
});
visit_all
(
res_y_idx
,
res_x_rev_idx
,
res_y_ct_idx
)(
[
&
](
auto
y_indices
,
auto
x_rev_indices
,
auto
y_count
)
{
std
::
copy
(
out_y_idx
.
begin
(),
out_y_idx
.
end
(),
y_indices
.
begin
());
std
::
copy
(
out_x_rev_idx
.
begin
(),
out_x_rev_idx
.
end
(),
x_rev_indices
.
begin
());
std
::
copy
(
out_y_ct
.
begin
(),
out_y_ct
.
end
(),
y_count
.
begin
());
sh_x_idx
=
{
sh_idx
.
type
(),
{
out_x_rev_idx
.
size
()}};
});
return
{{
res_y
.
reshape
(
sh_y
),
res_y_idx
.
reshape
(
sh_idx
),
res_x_rev_idx
.
reshape
(
sh_x_idx
),
res_y_ct_idx
.
reshape
(
sh_idx
)}};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
f8a75f8a
...
@@ -119,6 +119,8 @@
...
@@ -119,6 +119,8 @@
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_add.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_none.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_mul.hpp>
#include <migraphx/op/scatternd_max.hpp>
#include <migraphx/op/scatternd_min.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
#include <migraphx/op/sinh.hpp>
...
@@ -137,6 +139,7 @@
...
@@ -137,6 +139,7 @@
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unique.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#include <migraphx/op/where.hpp>
...
...
src/include/migraphx/par.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
#include <migraphx/config.hpp>
#if MIGRAPHX_HAS_EXECUTORS
#include <execution>
#else
#include <migraphx/simple_par_for.hpp>
#endif
#include <algorithm>
#include <mutex>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
detail
{
struct
exception_list
{
std
::
vector
<
std
::
exception_ptr
>
exceptions
;
std
::
mutex
m
;
void
add_exception
()
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
m
);
exceptions
.
push_back
(
std
::
current_exception
());
}
template
<
class
F
>
auto
collect
(
F
f
)
{
return
[
f
,
this
](
auto
&&
...
xs
)
{
try
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
}
catch
(...)
{
this
->
add_exception
();
}
};
}
void
throw_if_exception
()
const
{
if
(
not
exceptions
.
empty
())
std
::
rethrow_exception
(
exceptions
.
front
());
}
};
}
// namespace detail
template
<
class
InputIt
,
class
OutputIt
,
class
UnaryOperation
>
OutputIt
par_transform
(
InputIt
first1
,
InputIt
last1
,
OutputIt
d_first
,
UnaryOperation
unary_op
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
transform
(
std
::
execution
::
par
,
first1
,
last1
,
d_first
,
std
::
move
(
unary_op
));
#else
simple_par_for
(
last1
-
first1
,
[
&
](
auto
i
)
{
d_first
[
i
]
=
unary_op
(
first1
[
i
]);
});
return
d_first
+
(
last1
-
first1
);
#endif
}
template
<
class
InputIt1
,
class
InputIt2
,
class
OutputIt
,
class
BinaryOperation
>
OutputIt
par_transform
(
InputIt1
first1
,
InputIt1
last1
,
InputIt2
first2
,
OutputIt
d_first
,
BinaryOperation
binary_op
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
transform
(
std
::
execution
::
par
,
first1
,
last1
,
first2
,
d_first
,
std
::
move
(
binary_op
));
#else
simple_par_for
(
last1
-
first1
,
[
&
](
auto
i
)
{
d_first
[
i
]
=
binary_op
(
first1
[
i
],
first2
[
i
]);
});
return
d_first
+
(
last1
-
first1
);
#endif
}
template
<
class
InputIt
,
class
UnaryFunction
>
void
par_for_each
(
InputIt
first
,
InputIt
last
,
UnaryFunction
f
)
{
#if MIGRAPHX_HAS_EXECUTORS
// Propagate the exception
detail
::
exception_list
ex
;
std
::
for_each
(
std
::
execution
::
par
,
first
,
last
,
ex
.
collect
(
std
::
move
(
f
)));
ex
.
throw_if_exception
();
#else
simple_par_for
(
last
-
first
,
[
&
](
auto
i
)
{
f
(
first
[
i
]);
});
#endif
}
template
<
class
...
Ts
>
auto
par_copy_if
(
Ts
&&
...
xs
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
copy_if
(
std
::
execution
::
par
,
std
::
forward
<
Ts
>
(
xs
)...);
#else
return
std
::
copy_if
(
std
::
forward
<
Ts
>
(
xs
)...);
#endif
}
template
<
class
...
Ts
>
auto
par_sort
(
Ts
&&
...
xs
)
{
#if MIGRAPHX_HAS_EXECUTORS
return
std
::
sort
(
std
::
execution
::
par
,
std
::
forward
<
Ts
>
(
xs
)...);
#else
return
std
::
sort
(
std
::
forward
<
Ts
>
(
xs
)...);
#endif
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PAR_HPP
src/include/migraphx/par_for.hpp
View file @
f8a75f8a
...
@@ -24,93 +24,23 @@
...
@@ -24,93 +24,23 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAR_FOR_HPP
#include <thread>
#include <migraphx/par.hpp>
#include <cmath>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
tid
,
F
f
)
->
decltype
(
f
(
i
,
tid
))
{
f
(
i
,
tid
);
}
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
,
F
f
)
->
decltype
(
f
(
i
))
{
f
(
i
);
}
template
<
class
F
>
void
par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
thread_invoke
(
i
,
0
,
f
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
size_t
tid
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
,
&
tid
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
thread_invoke
(
i
,
tid
,
f
);
}
});
work
+=
grainsize
;
++
tid
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
void
par_for
(
std
::
size_t
n
,
F
f
)
{
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
using
iterator
=
basic_iota_iterator
<
id
,
std
::
size_t
>
;
n
/
std
::
max
<
std
::
size_t
>
(
1
,
min_grain
));
par_for_each
(
iterator
{
0
,
{}},
iterator
{
n
,
{}},
f
);
par_for_impl
(
n
,
threadsize
,
f
);
}
}
template
<
class
F
>
template
<
class
F
>
void
par_for
(
std
::
size_t
n
,
F
f
)
void
par_for
(
std
::
size_t
n
,
std
::
size_t
,
F
f
)
{
{
const
int
min_grain
=
8
;
par_for
(
n
,
f
);
par_for
(
n
,
min_grain
,
f
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
f8a75f8a
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <string>
#include <string>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/shape.hpp
View file @
f8a75f8a
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
...
@@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape
m(int32_type, int32_t) \
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t)
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
// clang-format on
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
...
...
src/include/migraphx/simple_par_for.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLE_PAR_FOR_HPP
#include <thread>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cassert>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
joinable_thread
:
std
::
thread
{
template
<
class
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
// NOLINT
{
}
joinable_thread
&
operator
=
(
joinable_thread
&&
other
)
=
default
;
joinable_thread
(
joinable_thread
&&
other
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
tid
,
F
f
)
->
decltype
(
f
(
i
,
tid
))
{
f
(
i
,
tid
);
}
template
<
class
F
>
auto
thread_invoke
(
std
::
size_t
i
,
std
::
size_t
,
F
f
)
->
decltype
(
f
(
i
))
{
f
(
i
);
}
template
<
class
F
>
void
simple_par_for_impl
(
std
::
size_t
n
,
std
::
size_t
threadsize
,
F
f
)
{
if
(
threadsize
<=
1
)
{
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
thread_invoke
(
i
,
0
,
f
);
}
else
{
std
::
vector
<
joinable_thread
>
threads
(
threadsize
);
// Using const here causes gcc 5 to ICE
#if(!defined(__GNUC__) || __GNUC__ != 5)
const
#endif
std
::
size_t
grainsize
=
std
::
ceil
(
static_cast
<
double
>
(
n
)
/
threads
.
size
());
std
::
size_t
work
=
0
;
std
::
size_t
tid
=
0
;
std
::
generate
(
threads
.
begin
(),
threads
.
end
(),
[
=
,
&
work
,
&
tid
]
{
auto
result
=
joinable_thread
([
=
]
{
std
::
size_t
start
=
work
;
std
::
size_t
last
=
std
::
min
(
n
,
work
+
grainsize
);
for
(
std
::
size_t
i
=
start
;
i
<
last
;
i
++
)
{
thread_invoke
(
i
,
tid
,
f
);
}
});
work
+=
grainsize
;
++
tid
;
return
result
;
});
assert
(
work
>=
n
);
}
}
template
<
class
F
>
void
simple_par_for
(
std
::
size_t
n
,
std
::
size_t
min_grain
,
F
f
)
{
const
auto
threadsize
=
std
::
min
<
std
::
size_t
>
(
std
::
thread
::
hardware_concurrency
(),
n
/
std
::
max
<
std
::
size_t
>
(
1
,
min_grain
));
simple_par_for_impl
(
n
,
threadsize
,
f
);
}
template
<
class
F
>
void
simple_par_for
(
std
::
size_t
n
,
F
f
)
{
const
int
min_grain
=
8
;
simple_par_for
(
n
,
min_grain
,
f
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tune_axis.hpp
View file @
f8a75f8a
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -24,21 +24,21 @@
...
@@ -24,21 +24,21 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#define MIGRAPHX_GUARD_OPERATORS_TUNE_AXIS_HPP
#include <utility>
#include <cstdint>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
int
tune_axis
(
const
int
n_dim
,
const
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
inline
int
tune_axis
(
int
n_dim
,
int
axis
,
const
std
::
string
&
op_name
=
"OPERATOR"
)
{
{
if
(
axis
>=
n_dim
or
std
::
abs
(
axis
)
>
n_dim
)
if
(
axis
<
0
)
{
axis
+=
n_dim
;
if
(
axis
<
0
or
axis
>=
n_dim
)
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
MIGRAPHX_THROW
(
to_upper
(
op_name
)
+
": axis is out of range."
);
}
return
(
axis
<
0
)
?
axis
+
n_dim
:
axis
;
return
axis
;
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/type_traits.hpp
View file @
f8a75f8a
...
@@ -28,25 +28,35 @@
...
@@ -28,25 +28,35 @@
#include <type_traits>
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#define MIGRAPHX_DETAIL_
EXTEND
_TRAIT
_FOR
(trait
, T
) \
#define MIGRAPHX_DETAIL_
DEFINE
_TRAIT(trait) \
template <class X> \
template <class X> \
struct trait : std::trait<X> \
struct trait : std::trait<X> \
{ \
{ \
}; \
};
\
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <> \
template <> \
struct trait<T> : std::true_type \
struct trait<T> : std::true_type \
{ \
{ \
};
};
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_floating_point
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_arithmetic
);
MIGRAPHX_DETAIL_DEFINE_TRAIT
(
is_signed
);
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
migraphx
::
fp8
::
fp8e4m3fnuz
)
template
<
class
T
>
template
<
class
T
>
using
accumulator_type
=
using
accumulator_type
=
std
::
conditional_t
<
is_floating_point
<
T
>
{},
std
::
conditional_t
<
is_floating_point
<
T
>
{},
...
...
src/onnx/CMakeLists.txt
View file @
f8a75f8a
...
@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
...
@@ -26,7 +26,11 @@ find_package(Protobuf REQUIRED)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
onnx-proto PRIVATE -w
)
if
(
MSVC
)
target_compile_options
(
onnx-proto PRIVATE /w
)
else
()
target_compile_options
(
onnx-proto PRIVATE -w
)
endif
()
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
set_target_properties
(
onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
set_target_properties
(
onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On
)
...
@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
...
@@ -37,7 +41,10 @@ set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
migraphx_generate_export_header
(
migraphx_onnx
)
migraphx_generate_export_header
(
migraphx_onnx
)
rocm_set_soversion
(
migraphx_onnx
${
MIGRAPHX_SO_VERSION
}
)
rocm_set_soversion
(
migraphx_onnx
${
MIGRAPHX_SO_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_onnx
)
rocm_clang_tidy_check
(
migraphx_onnx
)
target_link_libraries
(
migraphx_onnx PRIVATE onnx-proto
"-Wl,--exclude-libs,ALL"
)
target_link_libraries
(
migraphx_onnx PRIVATE onnx-proto
)
if
(
NOT WIN32
)
target_link_libraries
(
migraphx_onnx PRIVATE
"-Wl,--exclude-libs,ALL"
)
endif
()
target_link_libraries
(
migraphx_onnx PUBLIC migraphx
)
target_link_libraries
(
migraphx_onnx PUBLIC migraphx
)
rocm_install_targets
(
rocm_install_targets
(
...
...
src/onnx/include/migraphx/onnx/pooling.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
value
handle_pooling_values
(
const
op_desc
&
opd
,
onnx_parser
::
node_info
info
,
const
shape
&
in_shape
,
value
values
);
instruction_ref
add_pooling_op
(
const
op_desc
&
opd
,
onnx_parser
::
node_info
info
,
instruction_ref
l0
);
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/onnx/onnx.proto
View file @
f8a75f8a
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
//
//
//
Copyright (c) ONNX Project Contributors.
//
SPDX-License-Identifier: Apache-2.0
// Licensed under the MIT license.
syntax
=
"proto2"
;
syntax
=
"proto2"
;
...
@@ -27,13 +27,6 @@ package onnx_for_migraphx;
...
@@ -27,13 +27,6 @@ package onnx_for_migraphx;
// Notes
// Notes
//
//
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility
// Protobuf compatibility
//
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
...
@@ -92,15 +85,28 @@ enum Version {
...
@@ -92,15 +85,28 @@ enum Version {
// - Add sparse initializers
// - Add sparse initializers
IR_VERSION_2019_9_19
=
0x0000000000000006
;
IR_VERSION_2019_9_19
=
0x0000000000000006
;
// IR VERSION 7 published on <TBD>
// IR VERSION 7 published on May 8, 2020
// - Add support to allow function body graph to rely on multiple external opreator sets.
// - Add a list to promote inference graph's initializers to global and
// - Add a list to promote inference graph's initializers to global and
// mutable variables. Global variables are visible in all graphs of the
// mutable variables. Global variables are visible in all graphs of the
// stored models.
// stored models.
// - Add message TrainingInfoProto to store initialization
// - Add message TrainingInfoProto to store initialization
// method and training algorithm. The execution of TrainingInfoProto
// method and training algorithm. The execution of TrainingInfoProto
// can modify the values of mutable variables.
// can modify the values of mutable variables.
// - Make inference graph callable from TrainingInfoProto via GraphCall operator.
// - Implicitly add inference graph into each TrainingInfoProto's algorithm.
IR_VERSION
=
0x0000000000000007
;
IR_VERSION_2020_5_8
=
0x0000000000000007
;
// IR VERSION 8 published on July 30, 2021
// Introduce TypeProto.SparseTensor
// Introduce TypeProto.Optional
// Added a list of FunctionProtos local to the model
// Deprecated since_version and operator status from FunctionProto
IR_VERSION_2021_7_30
=
0x0000000000000008
;
// IR VERSION 9 published on TBD
// Added AttributeProto to FunctionProto so that default attribute values can be set.
// Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
IR_VERSION
=
0x0000000000000009
;
}
}
// Attributes
// Attributes
...
@@ -121,6 +127,7 @@ message AttributeProto {
...
@@ -121,6 +127,7 @@ message AttributeProto {
TENSOR
=
4
;
TENSOR
=
4
;
GRAPH
=
5
;
GRAPH
=
5
;
SPARSE_TENSOR
=
11
;
SPARSE_TENSOR
=
11
;
TYPE_PROTO
=
13
;
FLOATS
=
6
;
FLOATS
=
6
;
INTS
=
7
;
INTS
=
7
;
...
@@ -128,6 +135,7 @@ message AttributeProto {
...
@@ -128,6 +135,7 @@ message AttributeProto {
TENSORS
=
9
;
TENSORS
=
9
;
GRAPHS
=
10
;
GRAPHS
=
10
;
SPARSE_TENSORS
=
12
;
SPARSE_TENSORS
=
12
;
TYPE_PROTOS
=
14
;
}
}
// The name field MUST be present for this version of the IR.
// The name field MUST be present for this version of the IR.
...
@@ -159,6 +167,7 @@ message AttributeProto {
...
@@ -159,6 +167,7 @@ message AttributeProto {
optional
SparseTensorProto
sparse_tensor
=
22
;
// sparse tensor value
optional
SparseTensorProto
sparse_tensor
=
22
;
// sparse tensor value
// Do not use field below, it's deprecated.
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
// optional ValueProto v = 12; // value - subsumes everything but graph
optional
TypeProto
tp
=
14
;
// type proto
repeated
float
floats
=
7
;
// list of floats
repeated
float
floats
=
7
;
// list of floats
repeated
int64
ints
=
8
;
// list of ints
repeated
int64
ints
=
8
;
// list of ints
...
@@ -166,6 +175,7 @@ message AttributeProto {
...
@@ -166,6 +175,7 @@ message AttributeProto {
repeated
TensorProto
tensors
=
10
;
// list of tensors
repeated
TensorProto
tensors
=
10
;
// list of tensors
repeated
GraphProto
graphs
=
11
;
// list of graph
repeated
GraphProto
graphs
=
11
;
// list of graph
repeated
SparseTensorProto
sparse_tensors
=
23
;
// list of sparse tensors
repeated
SparseTensorProto
sparse_tensors
=
23
;
// list of sparse tensors
repeated
TypeProto
type_protos
=
15
;
// list of type protos
}
}
// Defines information on value, including the name, the type, and
// Defines information on value, including the name, the type, and
...
@@ -211,7 +221,7 @@ message NodeProto {
...
@@ -211,7 +221,7 @@ message NodeProto {
// TrainingInfoProto stores information for training a model.
// TrainingInfoProto stores information for training a model.
// In particular, this defines two functionalities: an initialization-step
// In particular, this defines two functionalities: an initialization-step
// and a training-algorithm-step. Initialization resets the model
// and a training-algorithm-step. Initialization resets the model
// back to its original state as if no training has been
consu
med.
// back to its original state as if no training has been
perfor
med.
// Training algorithm improves the model based on input data.
// Training algorithm improves the model based on input data.
//
//
// The semantics of the initialization-step is that the initializers
// The semantics of the initialization-step is that the initializers
...
@@ -224,8 +234,8 @@ message NodeProto {
...
@@ -224,8 +234,8 @@ message NodeProto {
// training algorithm's step. After the execution of a
// training algorithm's step. After the execution of a
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
// may be immediately updated. If the targeted training algorithm contains
// may be immediately updated. If the targeted training algorithm contains
// consecutive update st
ag
es (such as block coordinate descent methods),
// consecutive update ste
p
s (such as block coordinate descent methods),
// the user needs to create a TrainingInfoProto for each st
ag
e.
// the user needs to create a TrainingInfoProto for each ste
p
.
message
TrainingInfoProto
{
message
TrainingInfoProto
{
// This field describes a graph to compute the initial tensors
// This field describes a graph to compute the initial tensors
// upon starting the training process. Initialization graph has no input
// upon starting the training process. Initialization graph has no input
...
@@ -239,20 +249,38 @@ message TrainingInfoProto {
...
@@ -239,20 +249,38 @@ message TrainingInfoProto {
// iteration to zero.
// iteration to zero.
//
//
// By default, this field is an empty graph and its evaluation does not
// By default, this field is an empty graph and its evaluation does not
// produce any output.
// produce any output.
Thus, no initializer would be changed by default.
optional
GraphProto
initialization
=
1
;
optional
GraphProto
initialization
=
1
;
// This field represents a training algorithm step. Given required inputs,
// This field represents a training algorithm step. Given required inputs,
// it computes outputs to update initializers in its own or inference graph's
// it computes outputs to update initializers in its own or inference graph's
// initializer lists. In general, this graph contains loss node, gradient node,
// initializer lists. In general, this field contains loss node, gradient node,
// optimizer node, increment of iteration count, and some calls to the inference
// optimizer node, increment of iteration count.
// graph.
//
//
// The field algorithm.node is the only place the user can use GraphCall
// An execution of the training algorithm step is performed by executing the
// operator. The only callable graph is the one stored in ModelProto.graph.
// graph obtained by combining the inference graph (namely "ModelProto.graph")
// and the "algorithm" graph. That is, the actual the actual
// input/initializer/output/node/value_info/sparse_initializer list of
// the training graph is the concatenation of
// "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
// and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
// in that order. This combined graph must satisfy the normal ONNX conditions.
// Now, let's provide a visualization of graph combination for clarity.
// Let the inference graph (i.e., "ModelProto.graph") be
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
// and the "algorithm" graph be
// tensor_d -> Add -> tensor_e
// The combination process results
// tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
//
// Notice that an input of a node in the "algorithm" graph may reference the
// output of a node in the inference graph (but not the other way round). Also, inference
// node cannot reference inputs of "algorithm". With these restrictions, inference graph
// can always be run independently without training information.
//
//
// By default, this field is an empty graph and its evaluation does not
// By default, this field is an empty graph and its evaluation does not
// produce any output.
// produce any output. Evaluating the default training step never
// update any initializers.
optional
GraphProto
algorithm
=
2
;
optional
GraphProto
algorithm
=
2
;
// This field specifies the bindings from the outputs of "initialization" to
// This field specifies the bindings from the outputs of "initialization" to
...
@@ -284,23 +312,16 @@ message TrainingInfoProto {
...
@@ -284,23 +312,16 @@ message TrainingInfoProto {
// be multiple key-value pairs in "update_binding".
// be multiple key-value pairs in "update_binding".
//
//
// The initializers appears as keys in "update_binding" are considered
// The initializers appears as keys in "update_binding" are considered
// mutable
and globally-visible
variables. This implies some behaviors
// mutable variables. This implies some behaviors
// as described below.
// as described below.
//
//
// 1. We have only unique keys in all "update_binding"s so that two
global
// 1. We have only unique keys in all "update_binding"s so that two
// variables may not have the same name. This ensures that one
// variables may not have the same name. This ensures that one
//
global
variable is assigned up to once.
// variable is assigned up to once.
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// 2. The keys must appear in names of "ModelProto.graph.initializer" or
// "TrainingInfoProto.algorithm.initializer".
// "TrainingInfoProto.algorithm.initializer".
// 3. The values must be output names of "algorithm".
// 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
// 4. If an optional input of a graph is omitted when using GraphCall, the
// 4. Mutable variables are initialized to the value specified by the
// global variable with the same name may be used.
// 5. When using GraphCall, the users always can pass values to optional
// inputs of the called graph even if the associated initializers appears
// as keys in "update_binding"s.
// 6. The graphs in TrainingInfoProto's can use global variables as
// their operator inputs.
// 7. Mutable variables are initialized to the value specified by the
// corresponding initializer, and then potentially updated by
// corresponding initializer, and then potentially updated by
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
// "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
//
//
...
@@ -375,13 +396,31 @@ message ModelProto {
...
@@ -375,13 +396,31 @@ message ModelProto {
//
//
// If this field is empty, the training behavior of the model is undefined.
// If this field is empty, the training behavior of the model is undefined.
repeated
TrainingInfoProto
training_info
=
20
;
repeated
TrainingInfoProto
training_info
=
20
;
// A list of function protos local to the model.
//
// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
// In case of any conflicts the behavior (whether the model local functions are given higher priority,
// or standard opserator sets are given higher priotity or this is treated as error) is defined by
// the runtimes.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto and other model local FunctionProtos.
// Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
// or by 2 FunctionProtos then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same for every node in the function body.
//
// One FunctionProto can reference other FunctionProto in the model, however, recursive reference
// is not allowed.
repeated
FunctionProto
functions
=
25
;
};
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message
StringStringEntryProto
{
message
StringStringEntryProto
{
optional
string
key
=
1
;
optional
string
key
=
1
;
optional
string
value
=
2
;
optional
string
value
=
2
;
};
};
message
TensorAnnotation
{
message
TensorAnnotation
{
...
@@ -409,8 +448,9 @@ message GraphProto {
...
@@ -409,8 +448,9 @@ message GraphProto {
optional
string
name
=
2
;
// namespace Graph
optional
string
name
=
2
;
// namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
// A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that
// Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
// MAY also appear in the input list.
// The name MUST be unique across both initializer and sparse_initializer,
// but the name MAY also appear in the input list.
repeated
TensorProto
initializer
=
5
;
repeated
TensorProto
initializer
=
5
;
// Initializers (see above) stored in sparse format.
// Initializers (see above) stored in sparse format.
...
@@ -433,13 +473,8 @@ message GraphProto {
...
@@ -433,13 +473,8 @@ message GraphProto {
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated
TensorAnnotation
quantization_annotation
=
14
;
repeated
TensorAnnotation
quantization_annotation
=
14
;
// DO NOT USE the following fields, they were deprecated from earlier versions.
reserved
3
,
4
,
6
to
9
;
// repeated string input = 3;
reserved
"ir_version"
,
"producer_version"
,
"producer_tag"
,
"domain"
;
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
}
}
// Tensors
// Tensors
...
@@ -474,6 +509,17 @@ message TensorProto {
...
@@ -474,6 +509,17 @@ message TensorProto {
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16
=
16
;
BFLOAT16
=
16
;
// Non-IEEE floating-point format based on papers
// FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
// 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
// Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
// The computation usually happens inside a block quantize / dequantize
// fused by the runtime.
FLOAT8E4M3FN
=
17
;
// float 8, mostly used for coefficients, supports nan, not inf
FLOAT8E4M3FNUZ
=
18
;
// float 8, mostly used for coefficients, supports nan, not inf, no negative zero
FLOAT8E5M2
=
19
;
// follows IEEE 754, supports nan, inf, mostly used for gradients
FLOAT8E5M2FNUZ
=
20
;
// follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
// Future extensions go here.
// Future extensions go here.
}
}
...
@@ -507,11 +553,11 @@ message TensorProto {
...
@@ -507,11 +553,11 @@ message TensorProto {
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated
float
float_data
=
4
[
packed
=
true
];
repeated
float
float_data
=
4
[
packed
=
true
];
// For int32, uint8, int8, uint16, int16, bool, and float16 values
// For int32, uint8, int8, uint16, int16, bool,
float8,
and float16 values
// float16 values must be bit-wise converted to an uint16_t prior
// float16
and float8
values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL,
or
FLOAT16
// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16
, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated
int32
int32_data
=
5
[
packed
=
true
];
repeated
int32
int32_data
=
5
[
packed
=
true
];
// For strings.
// For strings.
...
@@ -589,6 +635,8 @@ message TensorProto {
...
@@ -589,6 +635,8 @@ message TensorProto {
message
SparseTensorProto
{
message
SparseTensorProto
{
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The sequence of non-default values are encoded as a tensor of shape [NNZ].
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// The default-value is zero for numeric tensors, and empty-string for string tensors.
// values must have a non-empty name present which serves as a name for SparseTensorProto
// when used in sparse_initializer list.
optional
TensorProto
values
=
1
;
optional
TensorProto
values
=
1
;
// The indices of the non-default values, which may be stored in one of two formats.
// The indices of the non-default values, which may be stored in one of two formats.
...
@@ -619,7 +667,7 @@ message TensorShapeProto {
...
@@ -619,7 +667,7 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/ma
ster
/docs/DimensionDenotation.md#denotation-definition
// Refer to https://github.com/onnx/onnx/blob/ma
in
/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
// for pre-defined dimension denotations.
optional
string
denotation
=
3
;
optional
string
denotation
=
3
;
};
};
...
@@ -656,6 +704,23 @@ message TypeProto {
...
@@ -656,6 +704,23 @@ message TypeProto {
optional
TypeProto
value_type
=
2
;
optional
TypeProto
value_type
=
2
;
};
};
// wrapper for Tensor, Sequence, or Map
message
Optional
{
// The type and optional shape of the element wrapped.
// This field MUST be present for this version of the IR.
// Possible values correspond to OptionalProto.DataType enum
optional
TypeProto
elem_type
=
1
;
};
message
SparseTensor
{
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional
int32
elem_type
=
1
;
optional
TensorShapeProto
shape
=
2
;
}
oneof
value
{
oneof
value
{
// The type of a tensor.
// The type of a tensor.
...
@@ -672,11 +737,18 @@ message TypeProto {
...
@@ -672,11 +737,18 @@ message TypeProto {
// The type of a map.
// The type of a map.
Map
map_type
=
5
;
Map
map_type
=
5
;
// The type of an optional.
Optional
optional_type
=
9
;
// Type of the sparse tensor
SparseTensor
sparse_tensor_type
=
8
;
}
}
// An optional denotation can be used to denote the whole
// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/ma
ster
/docs/TypeDenotation.md#type-denotation-definition
// stored inside. Refer to https://github.com/onnx/onnx/blob/ma
in
/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
// for pre-defined type denotations.
optional
string
denotation
=
6
;
optional
string
denotation
=
6
;
}
}
...
@@ -696,7 +768,67 @@ message OperatorSetIdProto {
...
@@ -696,7 +768,67 @@ message OperatorSetIdProto {
optional
int64
version
=
2
;
optional
int64
version
=
2
;
}
}
// Operator/function status.
enum
OperatorStatus
{
EXPERIMENTAL
=
0
;
STABLE
=
1
;
}
message
FunctionProto
{
// The name of the function, similar usage of op_type in OperatorProto.
// Combined with FunctionProto.domain, this forms the unique identity of
// the FunctionProto.
optional
string
name
=
1
;
// Deprecated since IR Version 8
// optional int64 since_version = 2;
reserved
2
;
reserved
"since_version"
;
// Deprecated since IR Version 8
// optional OperatorStatus status = 3;
reserved
3
;
reserved
"status"
;
// The inputs and outputs of the function.
repeated
string
input
=
4
;
repeated
string
output
=
5
;
// The attribute parameters of the function.
// It is for function parameters without default values.
repeated
string
attribute
=
6
;
// The attribute protos of the function.
// It is for function attributes with default values.
// A function attribute shall be represented either as
// a string attribute or an AttributeProto, not both.
repeated
AttributeProto
attribute_proto
=
11
;
// The nodes in the function.
repeated
NodeProto
node
=
7
;
// A human-readable documentation for this function. Markdown is allowed.
optional
string
doc_string
=
8
;
// The OperatorSets this function body (graph) relies on.
//
// All nodes in the function body (graph) will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets. This means at most one version can be relied
// for one domain.
//
// The operator sets imported by FunctionProto should be compatible with the ones
// imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
// and ModelProto then versions for the operator set may be different but,
// the operator schema returned for op_type, domain, version combination
// for both the versions should be same.
repeated
OperatorSetIdProto
opset_import
=
9
;
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional
string
domain
=
10
;
}
// For using protobuf-lite
// For using protobuf-lite
option
optimize_for
=
LITE_RUNTIME
;
option
optimize_for
=
LITE_RUNTIME
;
\ No newline at end of file
src/onnx/onnx_parser.cpp
View file @
f8a75f8a
...
@@ -34,7 +34,9 @@
...
@@ -34,7 +34,9 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <onnx.pb.h>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -484,6 +486,8 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
SPARSE_TENSOR
:
case
onnx
::
AttributeProto
::
SPARSE_TENSOR
:
case
onnx
::
AttributeProto
::
SPARSE_TENSORS
:
case
onnx
::
AttributeProto
::
SPARSE_TENSORS
:
case
onnx
::
AttributeProto
::
TYPE_PROTOS
:
case
onnx
::
AttributeProto
::
TYPE_PROTO
:
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
MIGRAPHX_THROW
(
"PARSE_VALUE: Invalid attribute type "
+
std
::
to_string
(
attr
.
type
()));
MIGRAPHX_THROW
(
"PARSE_VALUE: Invalid attribute type "
+
std
::
to_string
(
attr
.
type
()));
...
@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
...
@@ -545,6 +549,18 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
FLOAT8E4M3FNUZ
:
{
std
::
vector
<
int32_t
>
data_int32
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
migraphx
::
fp8
::
fp8e4m3fnuz
>
data_fp8
;
std
::
transform
(
data_int32
.
begin
(),
data_int32
.
end
(),
std
::
back_inserter
(
data_fp8
),
[](
float
raw_val
)
{
return
migraphx
::
fp8
::
fp8e4m3fnuz
{
raw_val
};
});
return
create_literal
(
shape
::
fp8e4m3fnuz_type
,
dims
,
data_fp8
);
}
case
onnx
::
TensorProto
::
FLOAT8E5M2FNUZ
:
case
onnx
::
TensorProto
::
FLOAT8E5M2
:
case
onnx
::
TensorProto
::
FLOAT8E4M3FN
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
...
@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
...
@@ -609,6 +625,13 @@ shape::type_t get_type(int dtype)
case
11
:
return
shape
::
double_type
;
case
11
:
return
shape
::
double_type
;
case
12
:
return
shape
::
uint32_type
;
case
12
:
return
shape
::
uint32_type
;
case
13
:
return
shape
::
uint64_type
;
case
13
:
return
shape
::
uint64_type
;
case
18
:
return
shape
::
fp8e4m3fnuz_type
;
case
14
:
case
15
:
case
16
:
case
17
:
case
19
:
case
20
:
default:
{
default:
{
MIGRAPHX_THROW
(
"Prototensor data type "
+
std
::
to_string
(
dtype
)
+
" not supported"
);
MIGRAPHX_THROW
(
"Prototensor data type "
+
std
::
to_string
(
dtype
)
+
" not supported"
);
}
}
...
...
src/onnx/parse_lstm.cpp
View file @
f8a75f8a
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
}
}
void
lstm_transpose_inputs
(
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
0
]);
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_undefined
())
{
args
[
5
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
5
]);
}
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_undefined
())
{
args
[
6
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
6
]);
}
}
void
lstm_transpose_outputs
(
onnx_parser
::
node_info
&
info
,
instruction_ref
&
hidden_states
,
instruction_ref
&
last_output
,
instruction_ref
&
last_cell_output
)
{
std
::
vector
<
int64_t
>
perm_hs
{
2
,
0
,
1
,
3
};
hidden_states
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hs
}}),
hidden_states
);
std
::
vector
<
int64_t
>
perm_last
{
1
,
0
,
2
};
last_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_output
);
last_cell_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_cell_output
);
}
struct
parse_lstm
:
op_parser
<
parse_lstm
>
struct
parse_lstm
:
op_parser
<
parse_lstm
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
}
int
layout
=
0
;
if
(
contains
(
info
.
attributes
,
"layout"
))
{
layout
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"layout"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
if
(
args
.
size
()
<
8
)
{
{
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
}
if
(
layout
!=
0
)
{
lstm_transpose_inputs
(
info
,
args
);
}
// first output for concatenation of hidden states
// first output for concatenation of hidden states
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto
last_cell_output
=
auto
last_cell_output
=
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
if
(
layout
!=
0
)
{
lstm_transpose_outputs
(
info
,
hidden_states
,
last_output
,
last_cell_output
);
}
return
{
hidden_states
,
last_output
,
last_cell_output
};
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
}
};
};
...
...
src/onnx/parse_multinomial.cpp
View file @
f8a75f8a
...
@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
...
@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its
// use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size .
// number of elements is batch_size * sample_size .
size_t
batch_size
=
s0
.
lens
().
front
();
size_t
batch_size
=
s0
.
lens
().
front
();
auto
rand_dummy
=
info
.
add_literal
(
auto
rand_dummy
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
literal
{
migraphx
::
shape
::
float_type
,
{
batch_size
*
sample_size
}}
);
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
sample_size
}}
,
std
::
vector
<
float
>
(
batch_size
*
sample_size
)});
randoms
=
randoms
=
info
.
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
rand_dummy
);
info
.
add_instruction
(
migraphx
::
make_op
(
"random_uniform"
),
seed_input
,
rand_dummy
);
}
}
...
...
Prev
1
2
3
4
5
6
7
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment