Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
60a0f286
Commit
60a0f286
authored
Feb 22, 2022
by
Shucai Xiao
Browse files
changes to be able to parse the simple pytorch model
parent
2629e8f1
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
47 additions
and
18 deletions
+47
-18
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+2
-2
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+4
-3
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+8
-0
src/include/migraphx/op/shape_op.hpp
src/include/migraphx/op/shape_op.hpp
+1
-1
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-0
src/instruction.cpp
src/instruction.cpp
+9
-1
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+13
-10
src/shape.cpp
src/shape.cpp
+7
-0
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+1
-1
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
60a0f286
...
@@ -59,8 +59,8 @@ struct broadcast
...
@@ -59,8 +59,8 @@ struct broadcast
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
input
.
elements
())
//
if(output.elements() < input.elements())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to input size"
);
//
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return
output
;
return
output
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/convolution.hpp
View file @
60a0f286
...
@@ -55,7 +55,8 @@ struct convolution
...
@@ -55,7 +55,8 @@ struct convolution
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
min_ndims
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
same_type
().
same_ndims
().
min_ndims
(
3
);
check_attribute_size
();
check_attribute_size
();
// dim num of input and attribute should match
// dim num of input and attribute should match
auto
input_size
=
inputs
[
0
].
lens
().
size
();
auto
in_lens
=
inputs
[
0
].
lens
();
auto
input_size
=
in_lens
.
size
();
auto
padding_size
=
padding
.
size
();
auto
padding_size
=
padding
.
size
();
if
(
not
(
input_size
==
padding_size
/
2
+
2
or
input_size
==
padding_size
+
2
))
if
(
not
(
input_size
==
padding_size
/
2
+
2
or
input_size
==
padding_size
+
2
))
{
{
...
@@ -73,7 +74,7 @@ struct convolution
...
@@ -73,7 +74,7 @@ struct convolution
if
(
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
if
(
input
.
lens
().
at
(
1
)
!=
(
weights
.
lens
().
at
(
1
)
*
group
))
MIGRAPHX_THROW
(
"CONVOLUTION: Mismatch channel numbers"
);
MIGRAPHX_THROW
(
"CONVOLUTION: Mismatch channel numbers"
);
std
::
vector
<
size_t
>
output_lens
{
in
put
.
lens
()
[
0
],
weights
.
lens
()[
0
]};
std
::
vector
<
size_t
>
output_lens
{
in
_
lens
[
0
],
weights
.
lens
()[
0
]};
for
(
size_t
i
=
0
;
i
<
kdims
;
i
++
)
for
(
size_t
i
=
0
;
i
<
kdims
;
i
++
)
{
{
...
@@ -82,7 +83,7 @@ struct convolution
...
@@ -82,7 +83,7 @@ struct convolution
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
padding_factor
=
padding
[
i
]
+
padding
[
i
+
kdims
];
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
output_lens
.
push_back
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
in
put
.
lens
()
[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
(
in
_
lens
[
i
+
2
]
-
(
1
+
dilation
[
i
]
*
(
weights
.
lens
()[
i
+
2
]
-
1
))
+
padding_factor
)
/
padding_factor
)
/
stride
[
i
]
+
stride
[
i
]
+
1
)));
1
)));
...
...
src/include/migraphx/op/reshape.hpp
View file @
60a0f286
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <migraphx/lifetime.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <utility>
#include <iostream>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -30,6 +31,13 @@ struct reshape
...
@@ -30,6 +31,13 @@ struct reshape
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
).
standard
();
// input shape is dynamic, return dim directly
if
(
inputs
.
front
().
dynamic
())
{
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
return
{
inputs
.
front
().
type
(),
rdims
};
}
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
...
...
src/include/migraphx/op/shape_op.hpp
View file @
60a0f286
...
@@ -20,7 +20,7 @@ struct shape_op
...
@@ -20,7 +20,7 @@ struct shape_op
return
{
shape
::
int64_type
,
lens
};
return
{
shape
::
int64_type
,
lens
};
}
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
auto
lens
=
args
.
front
().
get_shape
().
lens
();
auto
lens
=
args
.
front
().
get_shape
().
lens
();
...
...
src/include/migraphx/shape.hpp
View file @
60a0f286
...
@@ -125,6 +125,8 @@ struct shape
...
@@ -125,6 +125,8 @@ struct shape
bool
standard
()
const
;
bool
standard
()
const
;
/// Returns true if all strides are equal to 0 (scalar tensor)
/// Returns true if all strides are equal to 0 (scalar tensor)
bool
scalar
()
const
;
bool
scalar
()
const
;
/// Return true if any dim is 0
bool
dynamic
()
const
;
shape
normalize_standard
()
const
;
shape
normalize_standard
()
const
;
...
...
src/instruction.cpp
View file @
60a0f286
...
@@ -280,7 +280,7 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
...
@@ -280,7 +280,7 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
bool
instruction
::
can_eval
()
const
bool
instruction
::
can_eval
()
const
{
{
if
(
op
.
name
()
==
"@literal"
)
if
(
op
.
name
()
==
"@literal"
or
op
.
name
()
==
"shape"
)
{
{
return
true
;
return
true
;
}
}
...
@@ -301,10 +301,18 @@ argument instruction::eval(bool check_eval) const
...
@@ -301,10 +301,18 @@ argument instruction::eval(bool check_eval) const
{
{
return
this
->
get_literal
().
get_argument
();
return
this
->
get_literal
().
get_argument
();
}
}
else
if
(
op
.
name
()
==
"shape"
)
{
argument
arg
{
this
->
inputs
().
front
()
->
get_shape
()};
return
normalized_operator
().
compute
(
result
,
{
arg
});
}
if
(
is_context_free
(
op
))
if
(
is_context_free
(
op
))
{
{
if
(
check_eval
and
not
this
->
can_eval
())
if
(
check_eval
and
not
this
->
can_eval
())
{
return
{};
return
{};
}
std
::
vector
<
argument
>
args
;
std
::
vector
<
argument
>
args
;
std
::
transform
(
this
->
inputs
().
begin
(),
std
::
transform
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
this
->
inputs
().
end
(),
...
...
src/onnx/onnx_parser.cpp
View file @
60a0f286
...
@@ -248,10 +248,10 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -248,10 +248,10 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
}
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
map_input_dims
.
count
(
name
)
>
0
)
//
if(map_input_dims.count(name) > 0)
{
//
{
dims
=
map_input_dims
.
at
(
name
);
//
dims = map_input_dims.at(name);
}
//
}
shape
s
=
parse_type
(
input
.
type
(),
dims
);
shape
s
=
parse_type
(
input
.
type
(),
dims
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
...
@@ -262,6 +262,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -262,6 +262,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
std
::
cout
<<
"node_op_type = "
<<
node
.
op_type
()
<<
std
::
endl
;
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
{
...
@@ -404,10 +405,10 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -404,10 +405,10 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
const
std
::
vector
<
std
::
size_t
>&
input_dims
)
const
{
{
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
shape
::
type_t
shape_type
=
get_type
(
t
.
tensor_type
().
elem_type
());
if
(
!
input_dims
.
empty
())
//
if(!input_dims.empty())
{
//
{
return
{
shape_type
,
input_dims
};
//
return {shape_type, input_dims};
}
//
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
@@ -419,13 +420,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -419,13 +420,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{
{
if
(
static_cast
<
int
>
(
d
.
dim_value
())
<=
0
)
if
(
static_cast
<
int
>
(
d
.
dim_value
())
<=
0
)
{
{
return
default_dim_value
;
// return default_dim_value;
return
0
;
}
}
return
d
.
dim_value
();
return
d
.
dim_value
();
}
}
else
else
{
{
return
default_dim_value
;
// return default_dim_value;
return
0
;
}
}
});
});
...
...
src/shape.cpp
View file @
60a0f286
...
@@ -272,6 +272,13 @@ bool shape::scalar() const
...
@@ -272,6 +272,13 @@ bool shape::scalar() const
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
size_t
(
0
))
==
0
;
std
::
accumulate
(
this
->
strides
().
begin
(),
this
->
strides
().
end
(),
std
::
size_t
(
0
))
==
0
;
}
}
bool
shape
::
dynamic
()
const
{
if
(
scalar
())
return
false
;
const
auto
&
lens
=
this
->
lens
();
return
std
::
find
(
lens
.
begin
(),
lens
.
end
(),
0
)
!=
lens
.
end
();
}
bool
shape
::
standard
()
const
{
return
impl
->
m_standard
;
}
bool
shape
::
standard
()
const
{
return
impl
->
m_standard
;
}
shape
shape
::
normalize_standard
()
const
shape
shape
::
normalize_standard
()
const
...
...
src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
View file @
60a0f286
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
namespace
migraphx
{
namespace
migraphx
{
using
index_int
=
std
::
u
int32_t
;
using
index_int
=
std
::
int32_t
;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment