Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e23d5c4
"vscode:/vscode.git/clone" did not exist on "b936bcc72eae0a618173dc9b5676fb193739d055"
Commit
7e23d5c4
authored
Jul 25, 2018
by
Scott Thornton
Browse files
Deleted memcpy for Paul. Added fix in Reshape for cases with -1 dimension.
parent
8c2d316e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
32 deletions
+22
-32
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+14
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+8
-32
No files found.
src/include/migraph/operators.hpp
View file @
7e23d5c4
...
@@ -341,11 +341,25 @@ struct reshape
...
@@ -341,11 +341,25 @@ struct reshape
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
auto
n_neg_dims
=
std
::
count
(
dims
.
begin
(),
dims
.
end
(),
-
1
);
if
(
n_neg_dims
>
1
)
MIGRAPH_THROW
(
"Dimensions for reshape can only have one -1 dim"
);
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
{
if
(
dims
[
i
]
==
0
)
if
(
dims
[
i
]
==
0
)
rdims
[
i
]
=
idims
[
i
];
rdims
[
i
]
=
idims
[
i
];
}
}
if
(
n_neg_dims
>
0
)
{
size_t
missing_dim
=
-
inputs
.
front
().
elements
()
/
std
::
accumulate
(
rdims
.
begin
(),
rdims
.
end
(),
1
,
std
::
multiplies
<
int64_t
>
());
for
(
std
::
size_t
i
=
0
;
i
<
rdims
.
size
();
i
++
)
{
if
(
dims
[
i
]
==
-
1
)
rdims
[
i
]
=
missing_dim
;
}
}
if
(
dims
.
back
()
==
-
1
)
if
(
dims
.
back
()
==
-
1
)
{
{
rdims
.
pop_back
();
rdims
.
pop_back
();
...
...
src/onnx/onnx.cpp
View file @
7e23d5c4
...
@@ -320,10 +320,7 @@ struct onnx_parser
...
@@ -320,10 +320,7 @@ struct onnx_parser
std
::
string
s
=
t
.
raw_data
();
std
::
string
s
=
t
.
raw_data
();
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
{
{
std
::
vector
<
float
>
raw
(
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
float_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
{
{
...
@@ -331,38 +328,23 @@ struct onnx_parser
...
@@ -331,38 +328,23 @@ struct onnx_parser
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
{
{
std
::
vector
<
int32_t
>
raw
(
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
{
{
std
::
vector
<
int32_t
>
raw
(
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
{
{
std
::
vector
<
int32_t
>
raw
(
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
{
{
std
::
vector
<
int32_t
>
raw
(
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
{
{
std
::
vector
<
int64_t
>
raw
(
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int64_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
{
{
...
@@ -370,10 +352,7 @@ struct onnx_parser
...
@@ -370,10 +352,7 @@ struct onnx_parser
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
{
{
std
::
vector
<
int32_t
>
raw
(
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
{
{
...
@@ -381,10 +360,7 @@ struct onnx_parser
...
@@ -381,10 +360,7 @@ struct onnx_parser
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
{
{
std
::
vector
<
double
>
raw
(
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
double_type
,
dims
},
raw
};
}
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment