Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8c2d316e
Commit
8c2d316e
authored
Jul 24, 2018
by
Scott Thornton
Browse files
Able to read raw_data from onnx (at least in the case of Reshape)
parent
71c777bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
93 additions
and
0 deletions
+93
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+93
-0
No files found.
src/onnx/onnx.cpp
View file @
8c2d316e
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <unordered_map>
#include <unordered_map>
#include <functional>
#include <functional>
#include <array>
#include <array>
#include <vector>
#include <migraph/fallthrough.hpp>
#include <migraph/fallthrough.hpp>
#include <migraph/program.hpp>
#include <migraph/program.hpp>
...
@@ -314,6 +315,98 @@ struct onnx_parser
...
@@ -314,6 +315,98 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
if
(
t
.
has_raw_data
())
{
std
::
string
s
=
t
.
raw_data
();
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
{
std
::
vector
<
float
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
float_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
{
std
::
vector
<
int64_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int64_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
{
std
::
vector
<
int32_t
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
int32_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
{
std
::
vector
<
double
>
raw
(
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
()));
memcpy
(
raw
.
data
(),
s
.
data
(),
s
.
length
());
return
literal
{{
shape
::
double_type
,
dims
},
raw
};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX128
)
{
throw
std
::
runtime_error
(
""
);
}
else
{
MIGRAPH_THROW
(
"Invalid tensor type"
);
}
}
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment