Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
25e8cf0b
Unverified
Commit
25e8cf0b
authored
Jan 27, 2023
by
Ted Themistokleous
Committed by
GitHub
Jan 27, 2023
Browse files
Merge branch 'develop' into test_onnx_zoo
parents
a313a68e
635502be
Changes
138
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
698 additions
and
223 deletions
+698
-223
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+71
-7
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+5
-5
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+63
-29
src/include/migraphx/op/transpose.hpp
src/include/migraphx/op/transpose.hpp
+31
-15
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+76
-41
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+13
-0
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+3
-1
src/insert_pad.cpp
src/insert_pad.cpp
+2
-2
src/instruction.cpp
src/instruction.cpp
+18
-0
src/module.cpp
src/module.cpp
+89
-0
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+31
-8
src/onnx/parse_gemm.cpp
src/onnx/parse_gemm.cpp
+56
-39
src/onnx/parse_matmul.cpp
src/onnx/parse_matmul.cpp
+58
-34
src/onnx/parse_pad.cpp
src/onnx/parse_pad.cpp
+6
-0
src/onnx/parse_pooling.cpp
src/onnx/parse_pooling.cpp
+82
-38
src/onnx/parse_reduce_op.cpp
src/onnx/parse_reduce_op.cpp
+1
-2
src/onnx/parse_reshape.cpp
src/onnx/parse_reshape.cpp
+1
-1
src/onnx/parse_transpose.cpp
src/onnx/parse_transpose.cpp
+1
-1
src/onnx/parse_trilu.cpp
src/onnx/parse_trilu.cpp
+90
-0
No files found.
src/include/migraphx/op/reshape.hpp
View file @
25e8cf0b
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -46,14 +47,60 @@ struct reshape
...
@@ -46,14 +47,60 @@ struct reshape
value
attributes
()
const
{
return
{{
"require_std_shape"
,
true
}};
}
value
attributes
()
const
{
return
{{
"require_std_shape"
,
true
}};
}
std
::
string
name
()
const
{
return
"reshape"
;
}
std
::
string
name
()
const
{
return
"reshape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
dyn_compute_shape
(
shape
s0
)
const
{
auto
dyn_dims
=
s0
.
dyn_dims
();
auto
num_not_fixed
=
std
::
count_if
(
dyn_dims
.
cbegin
(),
dyn_dims
.
cend
(),
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
});
if
(
num_not_fixed
!=
1
)
{
MIGRAPHX_THROW
(
"Reshape: Only supports one non-fixed dynamic_dimension"
);
}
// track number of fixed elements in input and output
std
::
size_t
num_dims_ele
=
1
;
std
::
size_t
num_dd_ele
=
1
;
for
(
std
::
size_t
i
=
0
;
i
<
dyn_dims
.
size
();
++
i
)
{
if
(
dyn_dims
[
i
].
is_fixed
())
{
num_dims_ele
*=
dims
[
i
];
num_dd_ele
*=
dyn_dims
[
i
].
min
;
}
else
{
if
(
dims
[
i
]
!=
0
and
dims
[
i
]
!=
-
1
)
{
MIGRAPHX_THROW
(
"Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension"
);
}
}
}
if
(
num_dims_ele
!=
num_dd_ele
)
{
MIGRAPHX_THROW
(
"Reshape: Number of fixed elements must match. Input: "
+
std
::
to_string
(
num_dd_ele
)
+
" Output: "
+
std
::
to_string
(
num_dims_ele
));
}
// construct output dynamic shape from dims attribute
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
(
dims
.
size
());
std
::
transform
(
dims
.
cbegin
(),
dims
.
cend
(),
dyn_dims
.
cbegin
(),
output_dyn_dims
.
begin
(),
[](
std
::
size_t
dim
,
auto
dyn_dim
)
{
if
(
not
dyn_dim
.
is_fixed
())
return
dyn_dim
;
return
shape
::
dynamic_dimension
{
dim
,
dim
};
});
return
{
s0
.
type
(),
output_dyn_dims
};
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPHX_THROW
(
"Reshape: Dimensions for reshape can only have one -1 dim"
);
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
{
...
@@ -86,9 +133,26 @@ struct reshape
...
@@ -86,9 +133,26 @@ struct reshape
return
s
;
return
s
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPHX_THROW
(
"Reshape: Dimensions for reshape can only have one -1 dim"
);
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
())
{
return
dyn_compute_shape
(
s0
);
}
else
{
return
static_compute_shape
(
inputs
,
n_neg_dims
);
}
}
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
...
...
src/include/migraphx/op/softmax.hpp
View file @
25e8cf0b
...
@@ -53,15 +53,15 @@ struct softmax
...
@@ -53,15 +53,15 @@ struct softmax
std
::
string
name
()
const
{
return
"softmax"
;
}
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
if
(
inputs
.
at
(
0
).
packed
())
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
()
or
s0
.
packed
())
{
{
return
inputs
.
at
(
0
)
;
return
s0
;
}
}
else
else
{
{
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
s0
.
type
(),
s0
.
lens
()};
return
{
inputs
.
at
(
0
).
type
(),
lens
};
}
}
}
}
...
...
src/include/migraphx/op/squeeze.hpp
View file @
25e8cf0b
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -54,52 +55,85 @@ struct squeeze
...
@@ -54,52 +55,85 @@ struct squeeze
std
::
string
name
()
const
{
return
"squeeze"
;
}
std
::
string
name
()
const
{
return
"squeeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
if
(
input_shape
.
dynamic
())
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
{
MIGRAPHX_THROW
(
"squeeze axis dimension should be equal to 1"
);
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
}
return
input_shape
.
dyn_dims
()[
axis
]
!=
1
;
std
::
vector
<
std
::
size_t
>
new_lens
;
}))
std
::
vector
<
std
::
size_t
>
new_strides
;
{
if
(
axes
.
empty
())
MIGRAPHX_THROW
(
{
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"
);
for
(
auto
i
:
range
(
old_lens
.
size
()))
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
if
(
axes
.
empty
())
{
std
::
copy_if
(
input_shape
.
dyn_dims
().
cbegin
(),
input_shape
.
dyn_dims
().
cend
(),
std
::
back_inserter
(
dyn_dims
),
[
&
](
auto
dd
)
{
return
dd
!=
1
;
});
}
else
{
{
if
(
old_lens
[
i
]
!=
1
)
for
(
auto
i
:
range
(
input_shape
.
ndim
())
)
{
{
new_lens
.
push_back
(
old_lens
[
i
]);
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
new_strides
.
push_back
(
old_strides
[
i
]);
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
()[
i
]);
}
}
}
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
}
}
else
else
{
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
std
::
any_of
(
axes
.
begin
(),
axes
.
end
(),
[
&
](
auto
axis
)
{
return
old_lens
[
axis
]
!=
1
;
}))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
MIGRAPHX_THROW
(
"SQUEEZE: static axis dimension should be equal to 1"
);
}
std
::
vector
<
std
::
size_t
>
new_lens
;
std
::
vector
<
std
::
size_t
>
new_strides
;
if
(
axes
.
empty
())
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
{
{
new_lens
.
push_back
(
old_lens
[
i
]);
if
(
old_lens
[
i
]
!=
1
)
new_strides
.
push_back
(
old_strides
[
i
]);
{
new_lens
.
push_back
(
old_lens
[
i
]);
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
}
}
}
else
if
(
new_lens
.
empty
())
{
{
for
(
auto
i
:
range
(
old_lens
.
size
()))
return
shape
{
type
};
{
}
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
else
{
{
new_lens
.
push_back
(
old_lens
[
i
]);
return
shape
{
type
,
new_lens
,
new_strides
};
new_strides
.
push_back
(
old_strides
[
i
]);
}
}
}
if
(
new_lens
.
empty
())
{
return
shape
{
type
};
}
else
{
return
shape
{
type
,
new_lens
,
new_strides
};
}
}
}
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/transpose.hpp
View file @
25e8cf0b
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -45,17 +46,15 @@ struct transpose
...
@@ -45,17 +46,15 @@ struct transpose
}
}
std
::
string
name
()
const
{
return
"transpose"
;
}
std
::
string
name
()
const
{
return
"transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
auto
t
=
input
.
type
();
if
(
dims
.
size
()
!=
input
_lens
.
size
())
if
(
dims
.
size
()
!=
input
.
ndim
())
{
{
MIGRAPHX_THROW
(
"Permutation has wrong number of axes"
);
MIGRAPHX_THROW
(
"
TRANSPOSE:
Permutation has wrong number of axes"
);
}
}
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
vector
<
int64_t
>
axes
(
dims
.
size
());
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
...
@@ -63,19 +62,36 @@ struct transpose
...
@@ -63,19 +62,36 @@ struct transpose
{
{
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
MIGRAPHX_THROW
(
"TRANSPOSE: Invalid permutation"
);
}
}
std
::
vector
<
size_t
>
output_lens
(
input_lens
.
size
());
std
::
vector
<
size_t
>
output_strides
(
input_lens
.
size
());
if
(
input
.
dynamic
())
for
(
std
::
size_t
i
=
0
;
i
<
output_lens
.
size
();
i
++
)
{
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
(
input
.
ndim
());
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
std
::
transform
(
dims
.
cbegin
(),
dims
.
cend
(),
output_dyn_dims
.
begin
(),
[
&
](
auto
dim
)
{
return
input
.
dyn_dims
()[
dim
];
});
return
{
input
.
type
(),
output_dyn_dims
};
}
else
{
auto
input_lens
=
input
.
lens
();
auto
input_strides
=
input
.
strides
();
std
::
vector
<
size_t
>
output_lens
(
input
.
ndim
());
std
::
vector
<
size_t
>
output_strides
(
input
.
ndim
());
for
(
std
::
size_t
i
=
0
;
i
<
input
.
ndim
();
i
++
)
{
output_lens
[
i
]
=
input_lens
[
dims
[
i
]];
output_strides
[
i
]
=
input_strides
[
dims
[
i
]];
}
return
{
input
.
type
(),
output_lens
,
output_strides
};
}
}
return
{
t
,
output_lens
,
output_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
25e8cf0b
...
@@ -29,11 +29,20 @@
...
@@ -29,11 +29,20 @@
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
...
@@ -56,63 +65,89 @@ struct unsqueeze
...
@@ -56,63 +65,89 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
if
(
input_shape
.
dynamic
())
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
{
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
if
(
not
steps
.
empty
())
return
shape
{
type
,
old_lens
};
{
else
MIGRAPHX_THROW
(
"UNSQUEEZE_dyn: nonempty steps attribute"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
std
::
vector
<
shape
::
dynamic_dimension
>
dyn_dims
=
{};
auto
new_ndim
=
input_shape
.
ndim
()
+
axes
.
size
();
std
::
size_t
k
=
0
;
for
(
auto
i
:
range
(
new_ndim
))
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
dyn_dims
.
push_back
({
1
,
1
,
0
});
}
else
{
dyn_dims
.
push_back
(
input_shape
.
dyn_dims
().
at
(
k
++
));
}
}
return
{
input_shape
.
type
(),
dyn_dims
};
}
}
else
{
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_strides
=
input_shape
.
strides
();
if
(
input_shape
.
scalar
())
{
if
(
old_lens
.
size
()
==
1
and
old_lens
.
front
()
==
1
)
return
shape
{
type
,
old_lens
};
else
MIGRAPHX_THROW
(
"UNSQUEEZE: Input must be a scalar"
);
}
if
(
steps
.
size
()
>
axes
.
size
())
if
(
steps
.
size
()
>
axes
.
size
())
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
MIGRAPHX_THROW
(
"UNSQUEEZE: Steps provided with no axis"
);
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_strides
(
new_size
);
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
std
::
int64_t
step
=
1
;
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
steps
.
size
())
if
(
axis_idx
<
axes
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
std
::
int64_t
step
=
1
;
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
if
(
axis_idx
<
steps
.
size
())
old_lens
[
p
]
/=
step
;
step
=
steps
[
axis_idx
];
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
else
{
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
else
else
{
{
if
(
step
!=
1
)
new_lens
[
i
]
=
old_lens
[
p
];
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
old_strides
[
p
++
];
new_strides
[
i
]
=
1
;
}
}
}
}
else
return
shape
{
type
,
new_lens
,
new_strides
};
{
new_lens
[
i
]
=
old_lens
[
p
];
new_strides
[
i
]
=
old_strides
[
p
++
];
}
}
}
return
shape
{
type
,
new_lens
,
new_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
src/include/migraphx/program.hpp
View file @
25e8cf0b
...
@@ -115,6 +115,7 @@ struct program
...
@@ -115,6 +115,7 @@ struct program
print_func
)
const
;
print_func
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_graph
(
std
::
ostream
&
os
,
bool
brief
=
false
)
const
;
void
print_py
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
print_cpp
(
std
::
ostream
&
os
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
void
dry_run
(
parameter_map
params
)
const
;
...
...
src/include/migraphx/shape.hpp
View file @
25e8cf0b
...
@@ -101,6 +101,19 @@ struct shape
...
@@ -101,6 +101,19 @@ struct shape
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
dynamic_dimension
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
dynamic_dimension
&
x
);
// compare to fixed std::size_t dimension
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
==
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
bool
operator
!=
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
bool
operator
!=
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
// add and subtract fixed std::size_t dimension
dynamic_dimension
&
operator
+=
(
const
std
::
size_t
&
x
);
dynamic_dimension
&
operator
-=
(
const
std
::
size_t
&
x
);
friend
dynamic_dimension
operator
+
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
friend
dynamic_dimension
operator
+
(
const
std
::
size_t
&
x
,
const
dynamic_dimension
&
y
);
friend
dynamic_dimension
operator
-
(
const
dynamic_dimension
&
x
,
const
std
::
size_t
&
y
);
};
};
static
const
std
::
vector
<
type_t
>&
types
();
static
const
std
::
vector
<
type_t
>&
types
();
...
...
src/include/migraphx/shape_for_each.hpp
View file @
25e8cf0b
...
@@ -31,6 +31,9 @@
...
@@ -31,6 +31,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
* Iterates the given function over the indices from the shape in order.
*/
template
<
class
F
>
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
{
{
...
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
...
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call
(
indices
);
call
(
indices
);
}
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/insert_pad.cpp
View file @
25e8cf0b
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
...
@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{
{
return
;
return
;
}
}
auto
kdims
=
input
->
get_shape
().
lens
().
size
()
-
2
;
auto
kdims
=
input
->
get_shape
().
ndim
()
-
2
;
if
(
std
::
equal
(
op
.
padding
.
begin
(),
if
(
std
::
equal
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
()))
op
.
padding
.
end
()))
return
;
return
;
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
lens
().
size
()
*
2
,
0
);
std
::
vector
<
int64_t
>
padding
(
input
->
get_shape
().
ndim
()
*
2
,
0
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
std
::
vector
<
size_t
>
pads_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
op
.
padding
=
std
::
vector
<
size_t
>
(
kdims
*
2
,
0
);
...
...
src/instruction.cpp
View file @
25e8cf0b
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
...
@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
std
::
replace
(
module_args
.
begin
(),
module_args
.
end
(),
old
,
new_mod
);
}
}
bool
instruction
::
is_undefined
()
const
{
if
(
op
.
name
()
==
"undefined"
)
{
return
true
;
}
else
if
(
this
->
inputs
().
empty
())
{
return
false
;
}
else
{
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
is_undefined
();
});
}
}
bool
instruction
::
can_eval
()
const
bool
instruction
::
can_eval
()
const
{
{
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
)
...
...
src/module.cpp
View file @
25e8cf0b
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
return
to_c_id
(
"x_"
+
replace_string
(
name
,
":"
,
"_module_"
));
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
auto
v
=
op
.
to_value
();
os
<<
"migraphx.op("
<<
enclose_name
(
op
.
name
());
auto
default_values
=
make_op
(
op
.
name
()).
to_value
();
for
(
auto
&&
x
:
v
)
{
auto
name
=
x
.
get_key
();
if
(
default_values
[
name
]
==
x
)
continue
;
os
<<
", "
<<
name
<<
"="
<<
to_json_string
(
x
.
without_key
());
}
os
<<
")"
;
}
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_make_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
...
@@ -804,6 +820,15 @@ static void print_make_op(std::ostream& os, const operation& op)
...
@@ -804,6 +820,15 @@ static void print_make_op(std::ostream& os, const operation& op)
os
<<
")"
;
os
<<
")"
;
}
}
static
void
print_py_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
os
<<
"migraphx.shape(type="
<<
to_json_string
(
s
.
type_string
())
<<
", lens="
<<
to_json_string
(
s
.
lens
());
if
(
not
s
.
standard
())
os
<<
", strides="
<<
to_json_string
(
s
.
strides
());
os
<<
")"
;
}
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
static
void
print_cpp_shape
(
std
::
ostream
&
os
,
const
migraphx
::
shape
&
s
)
{
{
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
os
<<
"migraphx::shape{migraphx::shape::"
<<
s
.
type_string
();
...
@@ -813,6 +838,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
...
@@ -813,6 +838,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os
<<
"}"
;
os
<<
"}"
;
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_py
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
names
)
const
{
// cppcheck-suppress variableScope
unsigned
long
seed
=
names
.
size
();
auto
last
=
std
::
prev
(
this
->
end
());
names
=
this
->
print
(
[
&
](
auto
ins
,
auto
ins_names
)
{
std
::
vector
<
std
::
string
>
input_vars
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
input_vars
),
[
&
](
auto
input
)
{
return
cpp_var_name
(
ins_names
.
at
(
input
));
});
if
(
ins
!=
last
)
os
<<
cpp_var_name
(
ins_names
.
at
(
ins
))
<<
" = "
;
if
(
ins
->
name
()
==
"@literal"
)
{
os
<<
mname
<<
".add_literal("
;
bool
use_abs
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
v
)
{
use_abs
=
std
::
none_of
(
v
.
begin
(),
v
.
end
(),
[](
auto
x
)
{
return
x
<
0
;
});
});
// Disable abs for now
use_abs
=
false
;
if
(
use_abs
)
os
<<
"migraphx.abs_literal("
;
os
<<
"migraphx.generate_literal("
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
", "
<<
seed
<<
")"
;
if
(
use_abs
)
os
<<
")"
;
os
<<
")"
<<
std
::
endl
;
seed
++
;
}
else
if
(
ins
->
name
()
==
"@param"
)
{
std
::
string
name
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
parameter
;
os
<<
mname
<<
".add_parameter("
<<
enclose_name
(
name
)
<<
","
;
print_py_shape
(
os
,
ins
->
get_shape
());
os
<<
")"
<<
std
::
endl
;
}
else
if
(
ins
->
name
()
==
"@return"
)
{
os
<<
mname
<<
".add_return(["
<<
join_strings
(
input_vars
,
", "
)
<<
"])"
<<
std
::
endl
;
}
else
{
assert
(
ins
->
name
().
front
()
!=
'@'
);
os
<<
mname
<<
".add_instruction("
;
print_py_op
(
os
,
ins
->
get_operator
());
os
<<
", ["
<<
join_strings
(
input_vars
,
", "
)
<<
"]"
;
os
<<
")"
<<
std
::
endl
;
}
},
names
);
return
names
;
}
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
module
::
print_cpp
(
std
::
ostream
&
os
,
module
::
print_cpp
(
std
::
ostream
&
os
,
const
std
::
string
&
mname
,
const
std
::
string
&
mname
,
...
@@ -874,6 +961,8 @@ module::print_cpp(std::ostream& os,
...
@@ -874,6 +961,8 @@ module::print_cpp(std::ostream& os,
return
names
;
return
names
;
}
}
void
module
::
print_py
(
std
::
ostream
&
os
)
const
{
this
->
print_py
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
print_cpp
(
std
::
ostream
&
os
)
const
{
this
->
print_cpp
(
os
,
this
->
name
(),
{});
}
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
void
module
::
annotate
(
std
::
ostream
&
os
,
std
::
function
<
void
(
instruction_ref
)
>
a
)
const
...
...
src/onnx/onnx_parser.cpp
View file @
25e8cf0b
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
{
{
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
auto
bias_bcast
=
mod
->
add_instruction
(
instruction_ref
bias_bcast
;
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
// if curr_ins has a dynamic output shape use 2 input broadcast
args
[
2
]);
if
(
curr_ins
->
get_shape
().
dynamic
())
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
}}),
args
[
2
],
curr_ins
);
}
else
{
bias_bcast
=
mod
->
add_instruction
(
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
curr_ins
->
get_shape
().
lens
()}}),
args
[
2
]);
}
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
return
mod
->
add_instruction
(
make_op
(
"add"
),
curr_ins
,
bias_bcast
);
}
}
return
curr_ins
;
return
curr_ins
;
...
@@ -393,18 +403,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -393,18 +403,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
literal
onnx_parser
::
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
const
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
if
(
not
t
.
external_data
().
empty
())
auto
type
=
get_type
(
t
.
data_type
());
shape
tensor_shape
(
type
,
dims
);
auto
external_data
=
t
.
external_data
();
if
(
not
external_data
.
empty
())
{
{
const
std
::
string
&
data_file
=
t
.
external_data
().
at
(
0
).
value
();
const
std
::
string
&
data_file
=
external_data
.
at
(
0
).
value
();
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
);
size_t
num_data_fields
=
external_data
.
size
();
size_t
offset
=
0
;
size_t
nbytes
=
tensor_shape
.
bytes
();
if
(
num_data_fields
>
1
)
// if offset field is present
{
offset
=
std
::
stoul
(
t
.
external_data
().
at
(
1
).
value
());
}
if
(
num_data_fields
>
2
)
// if nbytes field is present
{
nbytes
=
std
::
stoul
(
t
.
external_data
().
at
(
2
).
value
());
}
auto
raw_buffer
=
read_buffer
(
path
+
"/"
+
data_file
,
offset
,
nbytes
);
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
std
::
string
s
(
raw_buffer
.
begin
(),
raw_buffer
.
end
());
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
auto
type
=
get_type
(
t
.
data_type
());
return
create_literal
(
type
,
dims
,
s
.
data
());
return
create_literal
(
type
,
dims
,
s
.
data
());
}
}
...
...
src/onnx/parse_gemm.cpp
View file @
25e8cf0b
...
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
float
alpha
=
1.0
f
;
auto
a_arg
=
args
[
0
];
float
beta
=
1.0
f
;
auto
b_arg
=
args
[
1
];
bool
transa
=
false
;
if
(
a_arg
->
get_shape
().
ndim
()
!=
2
or
b_arg
->
get_shape
().
ndim
()
!=
2
)
bool
transb
=
false
;
{
MIGRAPHX_THROW
(
"PARSE_GEMM: A and B should be rank 2, A is rank "
+
std
::
to_string
(
a_arg
->
get_shape
().
ndim
())
+
", B is rank "
+
std
::
to_string
(
b_arg
->
get_shape
().
ndim
()));
}
float
alpha
=
1.0
f
;
float
beta
=
1.0
f
;
bool
trans_a
=
false
;
bool
trans_b
=
false
;
if
(
contains
(
info
.
attributes
,
"alpha"
))
if
(
contains
(
info
.
attributes
,
"alpha"
))
{
{
alpha
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
alpha
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
...
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
...
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
if
(
contains
(
info
.
attributes
,
"transA"
))
if
(
contains
(
info
.
attributes
,
"transA"
))
{
{
transa
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
trans
_
a
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
}
if
(
contains
(
info
.
attributes
,
"transB"
))
if
(
contains
(
info
.
attributes
,
"transB"
))
{
{
transb
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
trans
_
b
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
}
std
::
vector
<
int64_t
>
perm
(
args
[
0
]
->
get_shape
().
lens
().
size
());
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
int64_t
{
0
});
auto
dot_type
=
a_arg
->
get_shape
().
type
();
// swap the last two elements
std
::
swap
(
*
perm
.
rbegin
(),
*
(
perm
.
rbegin
()
+
1
));
auto
l1
=
args
[
0
];
auto
dot_type
=
l1
->
get_shape
().
type
();
if
(
alpha
!=
1.0
f
)
if
(
alpha
!=
1.0
f
)
{
{
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
auto
alpha_literal
=
info
.
add_literal
(
alpha
);
l1
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
l1
);
a_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
alpha_literal
,
a_arg
);
if
(
l1
->
get_shape
().
type
()
!=
dot_type
)
if
(
a_arg
->
get_shape
().
type
()
!=
dot_type
)
{
{
l1
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
l1
);
a_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
a_arg
);
}
}
}
}
l1
=
a_arg
=
(
trans_a
)
(
transa
)
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
l1
)
:
l1
;
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
a_arg
)
auto
l2
=
(
transb
)
:
a_arg
;
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
b_arg
=
(
trans_b
)
:
args
[
1
];
?
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
1
])
:
args
[
1
];
auto
ret
=
info
.
add_instruction
(
make_op
(
"dot"
),
l1
,
l2
);
auto
dot_ins
=
info
.
add_instruction
(
make_op
(
"dot"
),
a_arg
,
b_arg
);
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
if
(
not
float_equal
(
beta
,
0.0
f
)
&&
args
[
2
]
->
get_shape
().
elements
()
>
0
)
if
(
not
float_equal
(
beta
,
0.0
f
))
{
{
auto
out_lens
=
l1
->
get_shape
().
lens
();
auto
c_arg
=
args
[
2
];
out_lens
.
back
()
=
l2
->
get_shape
().
lens
().
back
();
if
(
dot_ins
->
get_shape
().
dynamic
())
auto
l3
=
args
[
2
];
auto
l3_lens
=
l3
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
l3_lens
.
begin
(),
l3_lens
.
end
()))
{
{
l3
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
),
args
[
2
],
dot_ins
);
args
[
2
]);
}
}
auto
beta_literal
=
info
.
add_literal
(
beta
);
else
auto
beta_l3
=
info
.
add_broadcastable_binary_op
(
"mul"
,
l3
,
beta_literal
);
if
(
beta_l3
->
get_shape
().
type
()
!=
dot_type
)
{
{
beta_l3
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
auto
out_lens
=
a_arg
->
get_shape
().
lens
();
beta_l3
);
out_lens
.
back
()
=
b_arg
->
get_shape
().
lens
().
back
();
auto
c_lens
=
c_arg
->
get_shape
().
lens
();
if
(
not
std
::
equal
(
out_lens
.
begin
(),
out_lens
.
end
(),
c_lens
.
begin
(),
c_lens
.
end
()))
{
c_arg
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
args
[
2
]);
}
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
ret
,
beta_l3
);
if
(
not
float_equal
(
beta
,
1.0
f
))
{
auto
beta_literal
=
info
.
add_literal
(
beta
);
c_arg
=
info
.
add_broadcastable_binary_op
(
"mul"
,
c_arg
,
beta_literal
);
if
(
c_arg
->
get_shape
().
type
()
!=
dot_type
)
{
c_arg
=
info
.
add_instruction
(
make_op
(
"convert"
,
{{
"target_type"
,
dot_type
}}),
c_arg
);
}
}
return
info
.
add_instruction
(
make_op
(
"add"
),
dot_ins
,
c_arg
);
}
}
}
}
return
dot_ins
;
return
ret
;
}
}
};
};
...
...
src/onnx/parse_matmul.cpp
View file @
25e8cf0b
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
const
onnx_parser
::
node_info
&
info
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
std
::
vector
<
instruction_ref
>
args
)
const
{
{
auto
l0
=
args
[
0
];
auto
a0
=
args
[
0
];
auto
l1
=
args
[
1
];
auto
a1
=
args
[
1
];
auto
l0_len
s
=
l
0
->
get_shape
()
.
lens
()
;
auto
s
0
=
a
0
->
get_shape
();
auto
l1_len
s
=
l
1
->
get_shape
()
.
lens
()
;
auto
s
1
=
a
1
->
get_shape
();
// args[0] is a vector, prepend 1 to the shape
instruction_ref
dot_res
;
bool
is_a_prepended
=
false
;
bool
is_a_prepended
=
false
;
if
(
l0_lens
.
size
()
==
1
)
bool
is_b_appended
=
false
;
if
(
s0
.
ndim
()
==
1
)
{
{
is_a_prepended
=
true
;
is_a_prepended
=
true
;
l0_lens
.
insert
(
l0_lens
.
begin
(),
1
);
a0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
l0
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
}}}),
args
[
0
]);
}
}
if
(
s1
.
ndim
()
==
1
)
bool
is_b_appended
=
false
;
if
(
l1_lens
.
size
()
==
1
)
{
{
is_b_appended
=
true
;
is_b_appended
=
true
;
l1_lens
.
push_back
(
1
);
a1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
l1
=
info
.
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
1
}}}),
args
[
1
]);
}
}
instruction_ref
bl0
=
l0
;
if
(
s0
.
dynamic
()
or
s1
.
dynamic
())
instruction_ref
bl1
=
l1
;
if
(
not
std
::
equal
(
l0_lens
.
rbegin
()
+
2
,
l0_lens
.
rend
(),
l1_lens
.
rbegin
()
+
2
,
l1_lens
.
rend
()))
{
{
auto
l0_it
=
l0_lens
.
begin
()
+
l0_lens
.
size
()
-
2
;
if
(
opd
.
op_name
==
"quant_dot"
)
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
l0_lens
.
begin
(),
l0_it
);
{
auto
l1_it
=
l1_lens
.
begin
()
+
l1_lens
.
size
()
-
2
;
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic MatMulInteger not supported"
)
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
l1_lens
.
begin
(),
l1_it
);
}
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
auto
s0_dds
=
a0
->
get_shape
().
to_dynamic
().
dyn_dims
(
);
l0_broadcasted_lens
=
output_lens
;
auto
s1_dds
=
a1
->
get_shape
().
to_dynamic
().
dyn_dims
()
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
l0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
// TODO: handling this case requires a new multibroadcast mode
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
l1_lens
.
end
());
if
(
not
std
::
equal
(
if
(
l0_lens
!=
l0_broadcasted_lens
)
s0_dds
.
rbegin
()
+
2
,
s0_dds
.
rend
(),
s1_dds
.
rbegin
()
+
2
,
s1_dds
.
rend
())
)
{
{
bl0
=
info
.
add_instruction
(
MIGRAPHX_THROW
(
"PARSE_MATMUL: dynamic shape broadcasting not supported"
);
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
l0
);
}
}
if
(
l1_lens
!=
l1_broadcasted_lens
)
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
a0
,
a1
);
}
else
{
auto
s0_lens
=
a0
->
get_shape
().
lens
();
auto
s1_lens
=
a1
->
get_shape
().
lens
();
instruction_ref
ba0
=
a0
;
instruction_ref
ba1
=
a1
;
// try broadcasting if dimensions other than last two do not match
if
(
not
std
::
equal
(
s0_lens
.
rbegin
()
+
2
,
s0_lens
.
rend
(),
s1_lens
.
rbegin
()
+
2
,
s1_lens
.
rend
()))
{
{
bl1
=
info
.
add_instruction
(
auto
l0_it
=
s0_lens
.
begin
()
+
s0_lens
.
size
()
-
2
;
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
l1
);
std
::
vector
<
std
::
size_t
>
l0_broadcasted_lens
(
s0_lens
.
begin
(),
l0_it
);
auto
l1_it
=
s1_lens
.
begin
()
+
s1_lens
.
size
()
-
2
;
std
::
vector
<
std
::
size_t
>
l1_broadcasted_lens
(
s1_lens
.
begin
(),
l1_it
);
auto
output_lens
=
compute_broadcasted_lens
(
l0_broadcasted_lens
,
l1_broadcasted_lens
);
l0_broadcasted_lens
=
output_lens
;
l0_broadcasted_lens
.
insert
(
l0_broadcasted_lens
.
end
(),
l0_it
,
s0_lens
.
end
());
l1_broadcasted_lens
=
output_lens
;
l1_broadcasted_lens
.
insert
(
l1_broadcasted_lens
.
end
(),
l1_it
,
s1_lens
.
end
());
if
(
s0_lens
!=
l0_broadcasted_lens
)
{
ba0
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l0_broadcasted_lens
}}),
a0
);
}
if
(
s1_lens
!=
l1_broadcasted_lens
)
{
ba1
=
info
.
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
l1_broadcasted_lens
}}),
a1
);
}
}
}
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
ba0
,
ba1
);
}
}
instruction_ref
dot_res
=
info
.
add_instruction
(
make_op
(
opd
.
op_name
),
bl0
,
bl1
);
int64_t
num_axis
=
static_cast
<
int64_t
>
(
dot_res
->
get_shape
().
lens
().
size
());
// squeeze the appended or prepended dimensions
int64_t
num_axis
=
dot_res
->
get_shape
().
ndim
();
if
(
is_a_prepended
)
if
(
is_a_prepended
)
{
{
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
dot_res
=
info
.
add_instruction
(
make_op
(
"squeeze"
,
{{
"axes"
,
{
num_axis
-
2
}}}),
dot_res
);
...
...
src/onnx/parse_pad.cpp
View file @
25e8cf0b
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
...
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
{
{
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
if
(
mode
==
"reflect"
)
if
(
mode
==
"reflect"
)
{
if
(
args
.
front
()
->
get_shape
().
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_PAD: reflect padding with dynamic shape not supported"
);
}
return
reflect_pad
(
info
,
pads
,
args
.
front
());
return
reflect_pad
(
info
,
pads
,
args
.
front
());
}
if
(
mode
!=
"constant"
)
if
(
mode
!=
"constant"
)
{
{
MIGRAPHX_THROW
(
MIGRAPHX_THROW
(
...
...
src/onnx/parse_pooling.cpp
View file @
25e8cf0b
...
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
{
"GlobalLpPool"
,
"lpnorm"
}};
{
"GlobalLpPool"
,
"lpnorm"
}};
}
}
instruction_ref
parse
(
const
op_desc
&
opd
,
value
handle_values
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
shape
&
in_shape
,
std
::
vector
<
instruction_ref
>
arg
s
)
const
value
value
s
)
const
{
{
const
std
::
unordered_map
<
std
::
string
,
op
::
pooling_mode
>
mode_map
=
{
auto
kdims
=
in_shape
.
ndim
()
-
2
;
{
"max"
,
op
::
pooling_mode
::
max
},
{
"average"
,
op
::
pooling_mode
::
average
},
{
"lpnorm"
,
op
::
pooling_mode
::
lpnorm
}};
std
::
string
mode
=
opd
.
op_name
;
if
(
not
contains
(
mode_map
,
mode
))
{
MIGRAPHX_THROW
(
"onnx pooling mode must be [
\"
max
\"
,
\"
average
\"
,
\"
lpnorm
\"
]"
);
}
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode_map
.
at
(
mode
)}});
value
values
=
op
.
to_value
();
auto
l0
=
args
[
0
];
auto
in_lens
=
l0
->
get_shape
().
lens
();
assert
(
in_lens
.
size
()
>
2
);
auto
kdims
=
in_lens
.
size
()
-
2
;
if
(
starts_with
(
opd
.
onnx_name
,
"Global"
))
if
(
starts_with
(
opd
.
onnx_name
,
"Global"
))
{
{
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
());
// if spatial dimensions are dynamic use dyn_global flag
if
(
in_shape
.
dynamic
()
and
std
::
any_of
(
in_shape
.
dyn_dims
().
cbegin
()
+
2
,
in_shape
.
dyn_dims
().
cend
(),
[](
auto
dd
)
{
return
not
dd
.
is_fixed
();
}))
{
values
[
"dyn_global"
]
=
true
;
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
();
}
else
{
// works with static and fixed dynamic shape
auto
m_lens
=
in_shape
.
max_lens
();
values
[
"lengths"
]
=
std
::
vector
<
size_t
>
(
m_lens
.
begin
()
+
2
,
m_lens
.
end
());
}
}
}
// does not support ceil_mode
if
(
contains
(
info
.
attributes
,
"ceil_mode"
))
if
(
contains
(
info
.
attributes
,
"ceil_mode"
))
{
{
values
[
"ceil_mode"
]
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"ceil_mode"
).
i
());
values
[
"ceil_mode"
]
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"ceil_mode"
).
i
());
}
}
// count include padding, if count include pad is 1, we always use
// explicit pad
int
count_include_pad
=
0
;
if
(
contains
(
info
.
attributes
,
"count_include_pad"
))
{
count_include_pad
=
info
.
attributes
.
at
(
"count_include_pad"
).
i
();
}
if
(
contains
(
info
.
attributes
,
"strides"
))
if
(
contains
(
info
.
attributes
,
"strides"
))
{
{
values
[
"stride"
].
clear
();
values
[
"stride"
].
clear
();
copy
(
info
.
attributes
[
"strides"
].
ints
(),
std
::
back_inserter
(
values
[
"stride"
]));
copy
(
info
.
attributes
[
"strides"
].
ints
(),
std
::
back_inserter
(
values
[
"stride"
]));
check_attr_sizes
(
kdims
,
values
[
"stride"
].
size
(),
"PARSE_POOLING: inconsistent strides"
);
check_attr_sizes
(
kdims
,
values
[
"stride"
].
size
(),
"PARSE_POOLING: inconsistent strides"
);
}
}
if
(
contains
(
info
.
attributes
,
"kernel_shape"
))
if
(
contains
(
info
.
attributes
,
"kernel_shape"
))
{
{
values
[
"lengths"
].
clear
();
values
[
"lengths"
].
clear
();
...
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
// ensure pads availabe only when auto_pad is "NOT_SET"
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode
(
info
,
"POOLING"
);
check_padding_mode
(
info
,
"POOLING"
);
return
values
;
}
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
std
::
string
mode
=
opd
.
op_name
;
const
std
::
unordered_map
<
std
::
string
,
op
::
pooling_mode
>
mode_map
=
{
{
"max"
,
op
::
pooling_mode
::
max
},
{
"average"
,
op
::
pooling_mode
::
average
},
{
"lpnorm"
,
op
::
pooling_mode
::
lpnorm
}};
if
(
not
contains
(
mode_map
,
mode
))
{
MIGRAPHX_THROW
(
"PARSE_POOLING: onnx pooling mode must be [
\"
max
\"
,
\"
average
\"
,
\"
lpnorm
\"
]"
);
}
operation
op
=
make_op
(
"pooling"
,
{{
"mode"
,
mode_map
.
at
(
mode
)}});
value
values
=
op
.
to_value
();
auto
l0
=
args
[
0
];
auto
in_shape
=
l0
->
get_shape
();
assert
(
in_shape
.
ndim
()
>
2
);
auto
kdims
=
in_shape
.
ndim
()
-
2
;
values
=
handle_values
(
opd
,
info
,
in_shape
,
values
);
// count include padding, if count include pad is 1, we always use
// explicit pad
int
count_include_pad
=
0
;
if
(
contains
(
info
.
attributes
,
"count_include_pad"
))
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape"
);
}
count_include_pad
=
info
.
attributes
.
at
(
"count_include_pad"
).
i
();
}
std
::
vector
<
int64_t
>
paddings
;
std
::
vector
<
int64_t
>
paddings
;
float
pad_val
=
((
mode
==
"max"
)
?
std
::
numeric_limits
<
float
>::
lowest
()
:
0.0
f
);
float
pad_val
=
((
mode
==
"max"
)
?
std
::
numeric_limits
<
float
>::
lowest
()
:
0.0
f
);
...
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -123,14 +153,22 @@ struct parse_pooling : op_parser<parse_pooling>
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
if
(
contains
(
info
.
attributes
,
"auto_pad"
))
{
{
values
[
"padding"
].
clear
();
if
(
in_shape
.
dynamic
())
// return paddings could be empty, then setting to 0 for no padding
{
cal_auto_padding_size
(
info
,
MIGRAPHX_THROW
(
values
,
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported"
);
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
}
{
1
,
1
},
else
in_lens
,
{
paddings
);
values
[
"padding"
].
clear
();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size
(
info
,
values
,
values
[
"lengths"
].
to_vector
<
std
::
size_t
>
(),
{
1
,
1
},
in_shape
.
lens
(),
paddings
);
}
}
}
if
(
paddings
.
size
()
!=
2
*
kdims
)
if
(
paddings
.
size
()
!=
2
*
kdims
)
...
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
values
[
"stride"
].
resize
(
kdims
);
values
[
"stride"
].
resize
(
kdims
);
std
::
fill_n
(
values
[
"stride"
].
begin
(),
kdims
,
1
);
std
::
fill_n
(
values
[
"stride"
].
begin
(),
kdims
,
1
);
}
}
// used to calculate the supposed output shape
// used to calculate the supposed output shape
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
std
::
vector
<
int64_t
>
orig_padding
=
paddings
;
...
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
...
@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
if
(
not
slice_start
.
empty
())
if
(
not
slice_start
.
empty
())
{
{
if
(
in_shape
.
dynamic
())
{
MIGRAPHX_THROW
(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape"
);
}
// calculate expected output shape
// calculate expected output shape
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
()
+
kdims
,
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
orig_padding
.
insert
(
orig_padding
.
begin
(),
2
,
0
);
...
...
src/onnx/parse_reduce_op.cpp
View file @
25e8cf0b
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
}
}
else
else
{
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
axes
.
resize
(
args
.
front
()
->
get_shape
().
ndim
());
axes
.
resize
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
}
}
}
}
...
...
src/onnx/parse_reshape.cpp
View file @
25e8cf0b
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if
(
args
.
size
()
==
2
)
if
(
args
.
size
()
==
2
)
{
{
auto
s
=
args
[
1
]
->
eval
();
auto
s
=
args
[
1
]
->
eval
();
check_arg_empty
(
s
,
"Reshape:
dynamic shape
is not supported"
);
check_arg_empty
(
s
,
"Reshape:
non-constant shape input
is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
}
}
...
...
src/onnx/parse_transpose.cpp
View file @
25e8cf0b
...
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
...
@@ -47,7 +47,7 @@ struct parse_transpose : op_parser<parse_transpose>
}
}
// if perm is empty, use the default value
// if perm is empty, use the default value
auto
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
auto
n_dim
=
args
.
front
()
->
get_shape
().
ndim
();
if
(
perm
.
empty
())
if
(
perm
.
empty
())
{
{
perm
.
resize
(
n_dim
);
perm
.
resize
(
n_dim
);
...
...
src/onnx/parse_trilu.cpp
0 → 100644
View file @
25e8cf0b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_trilu
:
op_parser
<
parse_trilu
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Trilu"
}};
}
instruction_ref
parse
(
const
op_desc
&
,
const
onnx_parser
&
,
const
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
input_shape
=
args
[
0
]
->
get_shape
();
assert
(
input_shape
.
ndim
()
>=
2
);
auto
input_lens
=
input_shape
.
lens
();
size_t
num_rows
=
*
(
input_lens
.
rbegin
()
+
1
);
size_t
num_cols
=
input_lens
.
back
();
int
k
=
0
;
bool
upper
=
true
;
if
(
args
.
size
()
>
1
)
{
auto
arg_k
=
args
[
1
]
->
eval
();
check_arg_empty
(
arg_k
,
"PARSE_TRILU: dynamic k not supported"
);
k
=
arg_k
.
at
<
int
>
();
}
if
(
k
<
0
)
MIGRAPHX_THROW
(
"PARSE_TRILU: negative k values not supported"
);
if
(
contains
(
info
.
attributes
,
"upper"
))
{
upper
=
static_cast
<
bool
>
(
info
.
attributes
.
at
(
"upper"
).
i
());
}
shape
::
type_t
output_type
=
args
[
0
]
->
get_shape
().
type
();
// when creating the mask, if upper == 1,
// the inner triangle will have values set to 0
std
::
vector
<
bool
>
mask_mat
(
num_rows
*
num_cols
,
upper
);
for
(
size_t
i
=
0
;
i
<
num_rows
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
std
::
min
(
k
,
static_cast
<
int
>
(
num_cols
));
j
++
)
{
mask_mat
[
i
*
num_cols
+
j
]
=
not
upper
;
}
k
++
;
}
auto
mask
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
output_type
,
{
num_rows
,
num_cols
}},
mask_mat
});
return
info
.
add_broadcastable_binary_op
(
"mul"
,
mask
,
args
[
0
]);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment