Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fdd4e403
Commit
fdd4e403
authored
Apr 18, 2019
by
Khalique
Browse files
fix shape constructor and data types
parent
77ef0c1d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
13 deletions
+13
-13
src/include/migraphx/op/squeeze.hpp
src/include/migraphx/op/squeeze.hpp
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-2
src/tf/tf.cpp
src/tf/tf.cpp
+8
-8
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+2
-2
No files found.
src/include/migraphx/op/squeeze.hpp
View file @
fdd4e403
...
@@ -58,7 +58,7 @@ struct squeeze
...
@@ -58,7 +58,7 @@ struct squeeze
if
(
new_lens
.
empty
())
if
(
new_lens
.
empty
())
{
{
return
shape
{
type
,
{
1
},
{
0
}
};
return
shape
{
type
};
}
}
else
else
{
{
...
...
src/onnx/onnx.cpp
View file @
fdd4e403
...
@@ -1432,7 +1432,7 @@ struct onnx_parser
...
@@ -1432,7 +1432,7 @@ struct onnx_parser
{
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
if
(
dims
.
empty
())
return
literal
{{
shape_type
,
{
1
},
{
0
}
},
data
};
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
}
...
@@ -1440,7 +1440,7 @@ struct onnx_parser
...
@@ -1440,7 +1440,7 @@ struct onnx_parser
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
T
data
)
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
T
data
)
{
{
if
(
dims
.
empty
())
if
(
dims
.
empty
())
return
literal
{{
shape_type
,
{
1
},
{
0
}
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
,
dims
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
,
dims
},
data
.
begin
(),
data
.
end
()};
}
}
...
...
src/tf/tf.cpp
View file @
fdd4e403
...
@@ -751,17 +751,17 @@ struct tf_parser
...
@@ -751,17 +751,17 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int
8
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
u
int
16
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int
16
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int
8
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
...
@@ -815,11 +815,11 @@ struct tf_parser
...
@@ -815,11 +815,11 @@ struct tf_parser
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
create_literal
(
shape
::
int
32
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
int
8
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
create_literal
(
shape
::
int
32
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
u
int
16
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
create_literal
(
shape
::
int
32
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
int
16
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
...
@@ -916,7 +916,7 @@ struct tf_parser
...
@@ -916,7 +916,7 @@ struct tf_parser
{
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if
(
dims
.
empty
()
or
(
dims
.
size
()
==
1
and
dims
.
front
()
==
1
))
if
(
dims
.
empty
()
or
(
dims
.
size
()
==
1
and
dims
.
front
()
==
1
))
return
literal
{{
shape_type
,
{
1
},
{
0
}
},
data
};
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
}
};
};
...
...
test/tf/tf_test.cpp
View file @
fdd4e403
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
int
axis
=
1
;
int
axis
=
1
;
// tf uses axis as the third input, and it is in int32 format
// tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser)
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
1
},
{
0
}
},
std
::
vector
<
int
>
{
axis
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
},
std
::
vector
<
int
>
{
axis
});
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
std
::
size_t
>
(
axis
)},
l0
,
l1
);
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
std
::
size_t
>
(
axis
)},
l0
,
l1
);
auto
prog
=
migraphx
::
parse_tf
(
"concat_test.pb"
,
false
);
auto
prog
=
migraphx
::
parse_tf
(
"concat_test.pb"
,
false
);
...
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
...
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
TEST_CASE
(
const_test
)
TEST_CASE
(
const_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}
},
std
::
vector
<
float
>
{
1.0
f
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
std
::
vector
<
float
>
{
1.0
f
});
auto
prog
=
migraphx
::
parse_tf
(
"constant_test.pb"
,
false
);
auto
prog
=
migraphx
::
parse_tf
(
"constant_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment