Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d4c643fb
Unverified
Commit
d4c643fb
authored
Jul 24, 2019
by
mvermeulen
Committed by
GitHub
Jul 24, 2019
Browse files
Merge branch 'develop' into enable-miopen-hipclang
parents
d8922562
abf1b8e4
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
101 additions
and
67 deletions
+101
-67
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+20
-50
src/targets/gpu/device/gather.cpp
src/targets/gpu/device/gather.cpp
+1
-1
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+18
-0
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+1
-0
src/tf/tf.cpp
src/tf/tf.cpp
+27
-8
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+16
-3
test/onnx/reshape_non_standard.onnx
test/onnx/reshape_non_standard.onnx
+0
-0
test/tf/gather_test.pb
test/tf/gather_test.pb
+0
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+18
-5
No files found.
src/onnx/onnx.cpp
View file @
d4c643fb
...
@@ -66,8 +66,8 @@ struct onnx_parser
...
@@ -66,8 +66,8 @@ struct onnx_parser
add_variadic_op
(
"Max"
,
op
::
max
{});
add_variadic_op
(
"Max"
,
op
::
max
{});
add_variadic_op
(
"Min"
,
op
::
min
{});
add_variadic_op
(
"Min"
,
op
::
min
{});
add_mem_op
(
"ArgMax"
,
&
onnx_parser
::
parse_argmax
);
add_mem_op
(
"ArgMax"
,
&
onnx_parser
::
parse_
arg_op
<
op
::
argmax
>
);
add_mem_op
(
"ArgMin"
,
&
onnx_parser
::
parse_argmin
);
add_mem_op
(
"ArgMin"
,
&
onnx_parser
::
parse_
arg_op
<
op
::
argmin
>
);
add_mem_op
(
"Cast"
,
&
onnx_parser
::
parse_cast
);
add_mem_op
(
"Cast"
,
&
onnx_parser
::
parse_cast
);
add_mem_op
(
"Clip"
,
&
onnx_parser
::
parse_clip
);
add_mem_op
(
"Clip"
,
&
onnx_parser
::
parse_clip
);
add_mem_op
(
"LRN"
,
&
onnx_parser
::
parse_lrn
);
add_mem_op
(
"LRN"
,
&
onnx_parser
::
parse_lrn
);
...
@@ -86,8 +86,8 @@ struct onnx_parser
...
@@ -86,8 +86,8 @@ struct onnx_parser
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"MatMul"
,
&
onnx_parser
::
parse_matmul
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
);
add_mem_op
(
"Softmax"
,
&
onnx_parser
::
parse_softmax
<
op
::
softmax
>
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_logsoftmax
);
add_mem_op
(
"LogSoftmax"
,
&
onnx_parser
::
parse_
softmax
<
op
::
logsoftmax
>
);
add_mem_op
(
"Squeeze"
,
&
onnx_parser
::
parse_squeeze
);
add_mem_op
(
"Squeeze"
,
&
onnx_parser
::
parse_squeeze
);
add_mem_op
(
"Unsqueeze"
,
&
onnx_parser
::
parse_unsqueeze
);
add_mem_op
(
"Unsqueeze"
,
&
onnx_parser
::
parse_unsqueeze
);
add_mem_op
(
"Slice"
,
&
onnx_parser
::
parse_slice
);
add_mem_op
(
"Slice"
,
&
onnx_parser
::
parse_slice
);
...
@@ -261,19 +261,10 @@ struct onnx_parser
...
@@ -261,19 +261,10 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
}
instruction_ref
template
<
class
Op
>
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
parse_softmax
(
const
std
::
string
&
,
{
const
attribute_map
&
attributes
,
auto
dims
=
args
.
front
()
->
get_shape
().
lens
();
std
::
vector
<
instruction_ref
>
args
)
auto
r
=
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
]),
1
,
1
}},
args
.
front
());
auto
s
=
prog
.
add_instruction
(
op
::
softmax
{},
r
);
return
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
}
instruction_ref
parse_logsoftmax
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
int
axis
=
1
;
int
axis
=
1
;
if
(
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"axis"
))
...
@@ -281,10 +272,11 @@ struct onnx_parser
...
@@ -281,10 +272,11 @@ struct onnx_parser
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
op
::
logsoftmax
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
Op
{
axis
},
std
::
move
(
args
));
}
}
instruction_ref
parse_argmax
(
const
std
::
string
&
,
template
<
class
Op
>
instruction_ref
parse_arg_op
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
...
@@ -302,39 +294,12 @@ struct onnx_parser
...
@@ -302,39 +294,12 @@ struct onnx_parser
if
(
keep_dims
==
0
)
if
(
keep_dims
==
0
)
{
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
auto
ins
=
prog
.
add_instruction
(
Op
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
axis
}},
ins
);
return
prog
.
add_instruction
(
op
::
squeeze
{{
axis
}},
ins
);
}
}
else
else
{
{
return
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
Op
{
axis
},
std
::
move
(
args
));
}
}
instruction_ref
parse_argmin
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
int64_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
static_cast
<
int64_t
>
(
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
());
}
int
keep_dims
=
1
;
if
(
contains
(
attributes
,
"keepdims"
))
{
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
if
(
keep_dims
==
0
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
axis
}},
ins
);
}
else
{
return
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
}
}
}
}
...
@@ -470,6 +435,12 @@ struct onnx_parser
...
@@ -470,6 +435,12 @@ struct onnx_parser
check_arg_empty
(
s
,
"Reshape: dynamic shape is not supported"
);
check_arg_empty
(
s
,
"Reshape: dynamic shape is not supported"
);
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
}
}
if
(
!
args
[
0
]
->
get_shape
().
standard
())
{
args
[
0
]
=
prog
.
add_instruction
(
op
::
contiguous
{},
args
[
0
]);
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
}
...
@@ -849,7 +820,7 @@ struct onnx_parser
...
@@ -849,7 +820,7 @@ struct onnx_parser
{
{
dtype
=
parse_value
(
attributes
.
at
(
"dtype"
)).
at
<
int
>
();
dtype
=
parse_value
(
attributes
.
at
(
"dtype"
)).
at
<
int
>
();
}
}
migraphx
::
shape
::
type_t
type
=
get_type
(
dtype
);
shape
::
type_t
type
=
get_type
(
dtype
);
if
(
contains
(
attributes
,
"input_as_shape"
))
if
(
contains
(
attributes
,
"input_as_shape"
))
{
{
...
@@ -972,7 +943,6 @@ struct onnx_parser
...
@@ -972,7 +943,6 @@ struct onnx_parser
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
arg_s
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
arg_s
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
auto
out_lens
=
compute_broadcasted_lens
(
in_lens
,
dims
);
auto
out_lens
=
compute_broadcasted_lens
(
in_lens
,
dims
);
return
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
0
]);
return
prog
.
add_instruction
(
op
::
multibroadcast
{
out_lens
},
args
[
0
]);
}
}
...
...
src/targets/gpu/device/gather.cpp
View file @
d4c643fb
...
@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
...
@@ -25,7 +25,7 @@ argument gather(hipStream_t stream, argument result, argument arg1, argument arg
arg2
.
visit
([
&
](
auto
indices
)
{
arg2
.
visit
([
&
](
auto
indices
)
{
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
const
auto
*
indices_ptr
=
device_cast
(
indices
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
auto
*
output_ptr
=
device_cast
(
output
.
data
());
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
gs_launch
(
stream
,
nelements
,
256
)([
=
](
auto
i
)
{
auto
idx
=
out_comp
.
multi
(
i
);
auto
idx
=
out_comp
.
multi
(
i
);
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
idx
[
axis_index
]
=
indices_ptr
[
idx
[
axis_index
]];
output_ptr
[
i
]
=
input
[
idx
];
output_ptr
[
i
]
=
input
[
idx
];
...
...
src/targets/gpu/gemm.cpp
View file @
d4c643fb
...
@@ -167,10 +167,28 @@ rb_type<T>* to_rocblas_type(T* x)
...
@@ -167,10 +167,28 @@ rb_type<T>* to_rocblas_type(T* x)
rocblas_half
to_rocblas_type
(
half
x
)
{
return
reinterpret_cast
<
const
rocblas_half
&>
(
x
);
}
rocblas_half
to_rocblas_type
(
half
x
)
{
return
reinterpret_cast
<
const
rocblas_half
&>
(
x
);
}
void
miopen_gemm
::
batch_not_transposed
(
const
std
::
vector
<
std
::
size_t
>&
strides
)
const
{
if
(
strides
.
size
()
<=
2
)
return
;
auto
dim_0
=
strides
.
size
()
-
2
;
auto
matrix_size
=
std
::
max
(
strides
[
dim_0
],
strides
[
dim_0
+
1
]);
std
::
vector
<
std
::
size_t
>
batch
(
strides
.
begin
(),
strides
.
begin
()
+
dim_0
);
if
(
std
::
adjacent_find
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
i
<
j
or
i
<
matrix_size
or
j
<
matrix_size
);
})
!=
batch
.
end
())
{
MIGRAPHX_THROW
(
"DOT: batch size {"
+
to_string_range
(
strides
)
+
"} is transposed!"
);
}
}
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
std
::
vector
<
shape
>
input_shapes
(
inputs
.
begin
(),
inputs
.
begin
()
+
inputs
.
size
()
-
1
);
std
::
vector
<
shape
>
input_shapes
(
inputs
.
begin
(),
inputs
.
begin
()
+
inputs
.
size
()
-
1
);
check_shapes
{
input_shapes
}.
not_broadcasted
();
check_shapes
{
input_shapes
}.
not_broadcasted
();
batch_not_transposed
(
inputs
[
0
].
strides
());
batch_not_transposed
(
inputs
[
1
].
strides
());
return
op
.
compute_shape
(
input_shapes
);
return
op
.
compute_shape
(
input_shapes
);
}
}
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
d4c643fb
...
@@ -24,6 +24,7 @@ struct miopen_gemm
...
@@ -24,6 +24,7 @@ struct miopen_gemm
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
;
argument
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
void
batch_not_transposed
(
const
std
::
vector
<
std
::
size_t
>&
strides
)
const
;
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
...
...
src/tf/tf.cpp
View file @
d4c643fb
...
@@ -174,13 +174,14 @@ struct tf_parser
...
@@ -174,13 +174,14 @@ struct tf_parser
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
,
false
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
,
false
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"GatherV2"
,
&
tf_parser
::
parse_gather
,
false
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
<
op
::
softmax
>
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
add_mem_op
(
"Transpose"
,
&
tf_parser
::
parse_transpose
,
false
);
add_mem_op
(
"Transpose"
,
&
tf_parser
::
parse_transpose
,
false
);
...
@@ -525,6 +526,14 @@ struct tf_parser
...
@@ -525,6 +526,14 @@ struct tf_parser
return
prog
.
add_instruction
(
op
::
reshape
{
new_dims
},
args
[
0
]);
return
prog
.
add_instruction
(
op
::
reshape
{
new_dims
},
args
[
0
]);
}
}
instruction_ref
parse_gather
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
int
axis
=
args
[
2
]
->
eval
().
at
<
int32_t
>
();
op
::
gather
op
{
axis
};
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
args
[
1
]});
}
instruction_ref
instruction_ref
parse_matmul
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_matmul
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
@@ -724,14 +733,24 @@ struct tf_parser
...
@@ -724,14 +733,24 @@ struct tf_parser
}
}
}
}
instruction_ref
// template to facilitate the logsoftmax later
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
template
<
class
Op
>
instruction_ref
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
auto
dims
=
args
.
front
()
->
get_shape
().
lens
();
int
axis
=
-
1
;
auto
r
=
auto
num_dims
=
args
[
0
]
->
get_shape
().
lens
().
size
();
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
]),
1
,
1
}},
args
.
front
());
if
(
contains
(
attributes
,
"axis"
))
auto
s
=
prog
.
add_instruction
(
op
::
softmax
{},
r
);
{
return
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
axis
=
static_cast
<
int
>
(
attributes
.
at
(
"axis"
).
i
());
}
if
(
axis
<
0
)
{
axis
+=
num_dims
;
}
return
prog
.
add_instruction
(
Op
{
axis
},
make_contiguous
(
args
[
0
]));
}
}
instruction_ref
parse_squeeze
(
const
std
::
string
&
,
instruction_ref
parse_squeeze
(
const
std
::
string
&
,
...
...
test/onnx/onnx_test.cpp
View file @
d4c643fb
...
@@ -423,9 +423,7 @@ TEST_CASE(softmax_test)
...
@@ -423,9 +423,7 @@ TEST_CASE(softmax_test)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
}});
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
}});
auto
r
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
3
,
1
,
1
}},
l0
);
p
.
add_instruction
(
migraphx
::
op
::
softmax
{
1
},
l0
);
auto
s
=
p
.
add_instruction
(
migraphx
::
op
::
softmax
{},
r
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
3
}},
s
);
auto
prog
=
migraphx
::
parse_onnx
(
"softmax_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"softmax_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -447,6 +445,21 @@ TEST_CASE(reshape_test)
...
@@ -447,6 +445,21 @@ TEST_CASE(reshape_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
reshape_non_standard
)
{
migraphx
::
program
p
;
migraphx
::
op
::
reshape
op
;
std
::
vector
<
int64_t
>
reshape_dims
{
4
,
3
,
2
};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
tran_x
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
1
}},
x
);
auto
cont_x
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
tran_x
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
4
,
3
,
2
}},
cont_x
);
auto
prog
=
migraphx
::
parse_onnx
(
"reshape_non_standard.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
shape_test
)
TEST_CASE
(
shape_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/onnx/reshape_non_standard.onnx
0 → 100644
View file @
d4c643fb
File added
test/tf/gather_test.pb
0 → 100644
View file @
d4c643fb
File added
test/tf/tf_test.cpp
View file @
d4c643fb
...
@@ -209,6 +209,22 @@ TEST_CASE(expanddims_test_neg_dims)
...
@@ -209,6 +209,22 @@ TEST_CASE(expanddims_test_neg_dims)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
gather_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
4
}});
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
}},
{
1
,
1
}});
p
.
add_literal
(
1
);
int
axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
gather
{
axis
},
l0
,
l1
);
auto
prog
=
optimize_tf
(
"gather_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
identity_test
)
TEST_CASE
(
identity_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -399,11 +415,8 @@ TEST_CASE(rsqrt_test)
...
@@ -399,11 +415,8 @@ TEST_CASE(rsqrt_test)
TEST_CASE
(
softmax_test
)
TEST_CASE
(
softmax_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
}});
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
}});
auto
dims
=
l0
->
get_shape
().
lens
();
p
.
add_instruction
(
migraphx
::
op
::
softmax
{
1
},
l0
);
auto
r
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
]),
1
,
1
}},
l0
);
auto
s
=
p
.
add_instruction
(
migraphx
::
op
::
softmax
{},
r
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
auto
prog
=
optimize_tf
(
"softmax_test.pb"
,
false
);
auto
prog
=
optimize_tf
(
"softmax_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment