Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a55c6eec
Unverified
Commit
a55c6eec
authored
May 01, 2019
by
mvermeulen
Committed by
GitHub
May 01, 2019
Browse files
Merge branch 'develop' into copy_program
parents
58a845fa
8245bcaf
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
80 additions
and
51 deletions
+80
-51
README.md
README.md
+8
-4
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+7
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+38
-24
src/tf/tf.cpp
src/tf/tf.cpp
+24
-16
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+1
-2
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+2
-2
No files found.
README.md
View file @
a55c6eec
...
@@ -12,14 +12,18 @@ AMD's graph optimization engine.
...
@@ -12,14 +12,18 @@ AMD's graph optimization engine.
## Installing the dependencies
## Installing the dependencies
The dependencies can be installed with the
`install_deps.cmake`
, script:
`cmake -P install_deps.cmake`
.
Dependencies can be installed using the ROCm build tool
[
rbuild
](
https://github.com/RadeonOpenCompute/rbuild
)
.
This will install by default to
`/usr/local`
but it can be installed in another location with
`--prefix`
argument:
To install rbuild:
```
```
cmake -P install_deps.cmake --prefix /some/local/dir
pip install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
```
```
To build dependencies along with MIGraphX
```
rbuild build -d depend --cxx=/opt/rocm/bin/hcc
```
This builds dependencies in the subdirectory named depend and then builds MIGraphX using these dependencies.
## Building MIGraphX from source
## Building MIGraphX from source
...
...
src/include/migraphx/op/unsqueeze.hpp
View file @
a55c6eec
...
@@ -29,9 +29,13 @@ struct unsqueeze
...
@@ -29,9 +29,13 @@ struct unsqueeze
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
std
::
string
name
()
const
{
return
"unsqueeze"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
auto
input_shape
=
inputs
[
0
];
auto
input_shape
=
inputs
[
0
];
auto
type
=
input_shape
.
type
();
auto
type
=
input_shape
.
type
();
auto
old_lens
=
input_shape
.
lens
();
auto
old_lens
=
input_shape
.
lens
();
if
(
input_shape
.
scalar
())
return
shape
{
type
,
old_lens
};
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
size_t
new_size
=
old_lens
.
size
()
+
axes
.
size
();
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
vector
<
std
::
size_t
>
new_lens
(
new_size
);
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
...
...
src/onnx/onnx.cpp
View file @
a55c6eec
...
@@ -1361,28 +1361,26 @@ struct onnx_parser
...
@@ -1361,28 +1361,26 @@ struct onnx_parser
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
{
dims
=
{
1
};
}
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
s
.
data
()
}
;
case
onnx
::
TensorProto
::
FLOAT
:
return
create_
literal
(
shape
::
float_type
,
dims
,
s
.
data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
@@ -1394,21 +1392,21 @@ struct onnx_parser
...
@@ -1394,21 +1392,21 @@ struct onnx_parser
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
}
,
t
.
float_data
()
.
begin
(),
t
.
float_data
().
end
()}
;
return
create_
literal
(
shape
::
float_type
,
dims
,
t
.
float_data
()
)
;
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
}
,
t
.
int64_data
()
.
begin
(),
t
.
int64_data
().
end
()}
;
return
create_
literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
()
)
;
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
t
.
int32_data
()
.
begin
(),
t
.
int32_data
().
end
()}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
()
)
;
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
FLOAT16
:
{
{
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
...
@@ -1417,11 +1415,10 @@ struct onnx_parser
...
@@ -1417,11 +1415,10 @@ struct onnx_parser
data_uint16
.
end
(),
data_uint16
.
end
(),
std
::
back_inserter
(
data_half
),
std
::
back_inserter
(
data_half
),
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
literal
{{
shape
::
half_type
,
dims
}
,
data_half
.
begin
(),
data_half
.
end
()}
;
return
create_
literal
(
shape
::
half_type
,
dims
,
data_half
)
;
}
}
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
@@ -1430,6 +1427,23 @@ struct onnx_parser
...
@@ -1430,6 +1427,23 @@ struct onnx_parser
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
const
char
*
data
)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
std
::
is_pointer
<
T
>{})
>
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
T
data
)
{
if
(
dims
.
empty
())
return
literal
{{
shape_type
},
data
.
begin
(),
data
.
end
()};
return
literal
{{
shape_type
,
dims
},
data
.
begin
(),
data
.
end
()};
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
{
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
...
...
src/tf/tf.cpp
View file @
a55c6eec
...
@@ -741,10 +741,6 @@ struct tf_parser
...
@@ -741,10 +741,6 @@ struct tf_parser
static
literal
parse_tensor
(
const
tensorflow
::
TensorProto
&
t
)
static
literal
parse_tensor
(
const
tensorflow
::
TensorProto
&
t
)
{
{
std
::
vector
<
size_t
>
dims
=
parse_dims
(
t
.
tensor_shape
());
std
::
vector
<
size_t
>
dims
=
parse_dims
(
t
.
tensor_shape
());
if
(
dims
.
empty
())
{
dims
=
{
1
};
}
size_t
shape_size
=
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
());
size_t
shape_size
=
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
1
,
std
::
multiplies
<
size_t
>
());
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
{
{
...
@@ -755,17 +751,17 @@ struct tf_parser
...
@@ -755,17 +751,17 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int
8
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
u
int
16
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int
16
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int
32
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int
8
_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
...
@@ -815,21 +811,23 @@ struct tf_parser
...
@@ -815,21 +811,23 @@ struct tf_parser
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
get_data_vals
(
t
.
float_val
(),
shape_size
)};
return
create_literal
(
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int
32
_type
,
dims
}
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
}
;
return
create_
literal
(
shape
::
int
8
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
)
;
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
}
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
}
;
return
create_
literal
(
shape
::
u
int
16
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
)
;
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int
32
_type
,
dims
}
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
}
;
return
create_
literal
(
shape
::
int
16
_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
)
;
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
)
)
;
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
get_data_vals
(
t
.
int64_val
(),
shape_size
)};
return
create_literal
(
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
}
,
get_data_vals
(
t
.
bool_val
(),
shape_size
)
}
;
return
create_
literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
)
)
;
case
tensorflow
::
DataType
::
DT_HALF
:
case
tensorflow
::
DataType
::
DT_HALF
:
{
{
std
::
vector
<
int
>
data_int32
=
get_data_vals
(
t
.
half_val
(),
shape_size
);
std
::
vector
<
int
>
data_int32
=
get_data_vals
(
t
.
half_val
(),
shape_size
);
...
@@ -839,7 +837,7 @@ struct tf_parser
...
@@ -839,7 +837,7 @@ struct tf_parser
data_uint16
.
end
(),
data_uint16
.
end
(),
std
::
back_inserter
(
data_half
),
std
::
back_inserter
(
data_half
),
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
literal
{{
shape
::
half_type
,
dims
}
,
data_half
}
;
return
create_
literal
(
shape
::
half_type
,
dims
,
data_half
)
;
}
}
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
...
@@ -911,6 +909,16 @@ struct tf_parser
...
@@ -911,6 +909,16 @@ struct tf_parser
[](
tensorflow
::
TensorShapeProto_Dim
dim
)
{
return
dim
.
size
();
});
[](
tensorflow
::
TensorShapeProto_Dim
dim
)
{
return
dim
.
size
();
});
return
dims
;
return
dims
;
}
}
template
<
class
T
>
static
literal
create_literal
(
shape
::
type_t
shape_type
,
const
std
::
vector
<
size_t
>&
dims
,
std
::
vector
<
T
>
data
)
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if
(
dims
.
empty
()
or
(
dims
.
size
()
==
1
and
dims
.
front
()
==
1
))
return
literal
{{
shape_type
},
data
};
return
literal
{{
shape_type
,
dims
},
data
};
}
};
};
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
...
...
test/onnx/onnx_test.cpp
View file @
a55c6eec
...
@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test)
...
@@ -699,8 +699,7 @@ TEST_CASE(add_scalar_test)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
l1
=
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
{
1
}});
p
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}},
{
1
}});
auto
m0
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l0
);
auto
m0
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l0
);
auto
m1
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l1
);
auto
m1
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
2
,
3
,
4
,
5
}},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
m0
,
m1
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
m0
,
m1
);
...
...
test/tf/tf_test.cpp
View file @
a55c6eec
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
...
@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
int
axis
=
1
;
int
axis
=
1
;
// tf uses axis as the third input, and it is in int32 format
// tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser)
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
1
}
},
std
::
vector
<
int
>
{
axis
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
},
std
::
vector
<
int
>
{
axis
});
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
std
::
size_t
>
(
axis
)},
l0
,
l1
);
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
std
::
size_t
>
(
axis
)},
l0
,
l1
);
auto
prog
=
migraphx
::
parse_tf
(
"concat_test.pb"
,
false
);
auto
prog
=
migraphx
::
parse_tf
(
"concat_test.pb"
,
false
);
...
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
...
@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
TEST_CASE
(
const_test
)
TEST_CASE
(
const_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
}
},
std
::
vector
<
float
>
{
1.0
f
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
std
::
vector
<
float
>
{
1.0
f
});
auto
prog
=
migraphx
::
parse_tf
(
"constant_test.pb"
,
false
);
auto
prog
=
migraphx
::
parse_tf
(
"constant_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment