Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8c2d316e
Commit
8c2d316e
authored
Jul 24, 2018
by
Scott Thornton
Browse files
Able to read raw_data from onnx (at least in the case of Reshape)
parent
71c777bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
93 additions
and
0 deletions
+93
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+93
-0
No files found.
src/onnx/onnx.cpp
View file @
8c2d316e
...
...
@@ -6,6 +6,7 @@
#include <unordered_map>
#include <functional>
#include <array>
#include <vector>
#include <migraph/fallthrough.hpp>
#include <migraph/program.hpp>
...
...
@@ -314,6 +315,98 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
if
(
t
.
has_raw_data
())
{
std
::
string
s
=
t
.
raw_data
();
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
{
std
::
vector
<
float
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
float_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
{
std
::
vector
<
int64_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int64_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
{
std
::
vector
<
double
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
double_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX128
)
{
throw
std
::
runtime_error
(
""
);
}
else
{
MIGRAPH_THROW
(
"Invalid tensor type"
);
}
}
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment