Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9ea01307
Commit
9ea01307
authored
Jun 28, 2019
by
Shucai Xiao
Browse files
remove the keep dim attribute from argmax and argmin operators
parent
cf984059
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
48 additions
and
70 deletions
+48
-70
src/include/migraphx/op/argmax.hpp
src/include/migraphx/op/argmax.hpp
+1
-6
src/include/migraphx/op/argmin.hpp
src/include/migraphx/op/argmin.hpp
+1
-6
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+18
-2
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+8
-16
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+11
-13
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+4
-2
test/op_shape_test.cpp
test/op_shape_test.cpp
+5
-25
No files found.
src/include/migraphx/op/argmax.hpp
View file @
9ea01307
...
@@ -19,12 +19,11 @@ namespace op {
...
@@ -19,12 +19,11 @@ namespace op {
struct
argmax
struct
argmax
{
{
int
axis
=
0
;
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
keep_dims
,
"keep_dims"
)
);
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
}
std
::
string
name
()
const
{
return
"argmax"
;
}
std
::
string
name
()
const
{
return
"argmax"
;
}
...
@@ -40,10 +39,6 @@ struct argmax
...
@@ -40,10 +39,6 @@ struct argmax
}
}
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
if
(
keep_dims
==
0
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
return
{
shape
::
int64_type
,
lens
};
}
}
...
...
src/include/migraphx/op/argmin.hpp
View file @
9ea01307
...
@@ -19,12 +19,11 @@ namespace op {
...
@@ -19,12 +19,11 @@ namespace op {
struct
argmin
struct
argmin
{
{
int
axis
=
0
;
int
axis
=
0
;
int
keep_dims
=
1
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axis
,
"axis"
)
,
f
(
self
.
keep_dims
,
"keep_dims"
)
);
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
}
std
::
string
name
()
const
{
return
"argmin"
;
}
std
::
string
name
()
const
{
return
"argmin"
;
}
...
@@ -40,10 +39,6 @@ struct argmin
...
@@ -40,10 +39,6 @@ struct argmin
}
}
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
if
(
keep_dims
==
0
)
{
lens
.
erase
(
lens
.
begin
()
+
axis
);
}
return
{
shape
::
int64_type
,
lens
};
return
{
shape
::
int64_type
,
lens
};
}
}
...
...
src/onnx/onnx.cpp
View file @
9ea01307
...
@@ -284,7 +284,15 @@ struct onnx_parser
...
@@ -284,7 +284,15 @@ struct onnx_parser
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
op
::
argmax
{
axis
,
keep_dims
},
std
::
move
(
args
));
if
(
keep_dims
==
0
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
static_cast
<
int64_t
>
(
axis
)}},
ins
);
}
else
{
return
prog
.
add_instruction
(
op
::
argmax
{
axis
},
std
::
move
(
args
));
}
}
}
instruction_ref
parse_argmin
(
const
std
::
string
&
,
instruction_ref
parse_argmin
(
const
std
::
string
&
,
...
@@ -303,7 +311,15 @@ struct onnx_parser
...
@@ -303,7 +311,15 @@ struct onnx_parser
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
op
::
argmin
{
axis
,
keep_dims
},
std
::
move
(
args
));
if
(
keep_dims
==
0
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{{
static_cast
<
int64_t
>
(
axis
)}},
ins
);
}
else
{
return
prog
.
add_instruction
(
op
::
argmin
{
axis
},
std
::
move
(
args
));
}
}
}
instruction_ref
instruction_ref
...
...
test/cpu_ops_test.cpp
View file @
9ea01307
...
@@ -1135,8 +1135,7 @@ TEST_CASE(logsoftmax_test_axis_3)
...
@@ -1135,8 +1135,7 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
}
template
<
int
KeepDims
>
TEST_CASE
(
argmax_test_0
)
void
argmax_test_0
()
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
...
@@ -1145,7 +1144,7 @@ void argmax_test_0()
...
@@ -1145,7 +1144,7 @@ void argmax_test_0()
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
0
,
KeepDims
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
0
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
@@ -1154,9 +1153,6 @@ void argmax_test_0()
...
@@ -1154,9 +1153,6 @@ void argmax_test_0()
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmax_test_00
)
{
argmax_test_0
<
0
>
();
}
TEST_CASE
(
argmax_test_01
)
{
argmax_test_0
<
1
>
();
}
TEST_CASE
(
argmax_test_1
)
TEST_CASE
(
argmax_test_1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -1166,7 +1162,7 @@ TEST_CASE(argmax_test_1)
...
@@ -1166,7 +1162,7 @@ TEST_CASE(argmax_test_1)
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
2
,
1
,
2
,
0
,
0
,
2
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
2
,
1
,
2
,
0
,
0
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
1
,
0
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
1
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
@@ -1184,7 +1180,7 @@ TEST_CASE(argmax_test_2)
...
@@ -1184,7 +1180,7 @@ TEST_CASE(argmax_test_2)
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
3
,
2
,
2
,
2
,
3
};
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
3
,
2
,
2
,
2
,
3
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
2
,
0
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
2
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
@@ -1193,8 +1189,7 @@ TEST_CASE(argmax_test_2)
...
@@ -1193,8 +1189,7 @@ TEST_CASE(argmax_test_2)
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
}
template
<
int
KeepDims
>
TEST_CASE
(
argmin_test_0
)
void
argmin_test_0
()
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
...
@@ -1203,7 +1198,7 @@ void argmin_test_0()
...
@@ -1203,7 +1198,7 @@ void argmin_test_0()
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
1
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
0
,
1
,
0
};
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
1
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
0
,
1
,
0
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
0
,
KeepDims
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
0
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
@@ -1212,9 +1207,6 @@ void argmin_test_0()
...
@@ -1212,9 +1207,6 @@ void argmin_test_0()
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
}
TEST_CASE
(
argmin_test_00
)
{
argmin_test_0
<
0
>
();
}
TEST_CASE
(
argmin_test_01
)
{
argmin_test_0
<
1
>
();
}
TEST_CASE
(
argmin_test_1
)
TEST_CASE
(
argmin_test_1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -1224,7 +1216,7 @@ TEST_CASE(argmin_test_1)
...
@@ -1224,7 +1216,7 @@ TEST_CASE(argmin_test_1)
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
,
0
,
2
,
0
,
1
,
2
,
0
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
,
0
,
2
,
0
,
1
,
2
,
0
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
1
,
0
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
1
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
@@ -1242,7 +1234,7 @@ TEST_CASE(argmin_test_2)
...
@@ -1242,7 +1234,7 @@ TEST_CASE(argmin_test_2)
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
1
,
0
,
3
,
3
,
2
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
1
,
0
,
3
,
3
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
2
,
0
},
dl
);
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
2
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
std
::
vector
<
int64_t
>
result_vec
;
...
...
test/gpu/miopen.cpp
View file @
9ea01307
...
@@ -611,31 +611,29 @@ template struct test_softmax<1>;
...
@@ -611,31 +611,29 @@ template struct test_softmax<1>;
template
struct
test_softmax
<
2
>;
template
struct
test_softmax
<
2
>;
template
struct
test_softmax
<
3
>;
template
struct
test_softmax
<
3
>;
template
<
class
T
,
int
Axis
,
int
KeepDims
>
template
<
class
T
,
int
Axis
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
KeepDims
>>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
>>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1025
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1025
}};
auto
param
=
p
.
add_parameter
(
"data"
,
s
);
auto
param
=
p
.
add_parameter
(
"data"
,
s
);
p
.
add_instruction
(
T
{
Axis
,
KeepDims
},
param
);
p
.
add_instruction
(
T
{
Axis
},
param
);
return
p
;
return
p
;
}
}
};
};
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
struct
test_conv
:
verify_program
<
test_conv
>
struct
test_conv
:
verify_program
<
test_conv
>
{
{
...
...
test/onnx/onnx_test.cpp
View file @
9ea01307
...
@@ -788,7 +788,8 @@ TEST_CASE(argmax)
...
@@ -788,7 +788,8 @@ TEST_CASE(argmax)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l0
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
2
,
0
},
l0
);
auto
ins
=
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
2
},
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
2
}},
ins
);
auto
prog
=
migraphx
::
parse_onnx
(
"argmax_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"argmax_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
@@ -798,7 +799,8 @@ TEST_CASE(argmin)
...
@@ -798,7 +799,8 @@ TEST_CASE(argmin)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
auto
l0
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
3
,
0
},
l0
);
auto
ins
=
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
3
},
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
3
}},
ins
);
auto
prog
=
migraphx
::
parse_onnx
(
"argmin_test.onnx"
);
auto
prog
=
migraphx
::
parse_onnx
(
"argmin_test.onnx"
);
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
...
...
test/op_shape_test.cpp
View file @
9ea01307
...
@@ -385,47 +385,27 @@ void test_argop_var()
...
@@ -385,47 +385,27 @@ void test_argop_var()
{
{
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
1
,
3
,
4
,
5
}},
T
{
0
,
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
1
,
3
,
4
,
5
}},
T
{
0
},
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
4
,
5
}},
T
{
1
,
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
4
,
5
}},
T
{
1
},
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
1
,
5
}},
T
{
2
,
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
1
,
5
}},
T
{
2
},
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
1
}},
T
{
3
,
1
},
input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
1
}},
T
{
3
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
3
,
4
,
5
}},
T
{
0
,
0
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
4
,
5
}},
T
{
1
,
0
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
5
}},
T
{
2
,
0
},
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
}},
T
{
3
,
0
},
input
);
}
}
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
throws_shape
(
T
{
4
,
1
},
input
);
throws_shape
(
T
{
4
},
input
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment