Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4351f46c
Unverified
Commit
4351f46c
authored
Jun 27, 2021
by
Shucai Xiao
Committed by
GitHub
Jun 27, 2021
Browse files
Merge branch 'develop' into scatter-op
parents
27c0ae08
bc52a8a8
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
100 additions
and
31 deletions
+100
-31
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+6
-3
test/verify/test_if_literal.cpp
test/verify/test_if_literal.cpp
+2
-1
test/verify/test_if_lp.cpp
test/verify/test_if_lp.cpp
+3
-1
test/verify/test_if_param.cpp
test/verify/test_if_param.cpp
+2
-1
tools/include/operation.hpp
tools/include/operation.hpp
+87
-25
No files found.
test/ref_ops_test.cpp
View file @
4351f46c
...
...
@@ -1668,7 +1668,8 @@ TEST_CASE(if_literal_test)
else_mod
->
add_return
({
l2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
mm
->
add_return
({
ret
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -1730,7 +1731,8 @@ TEST_CASE(if_param_test)
else_mod
->
add_return
({
a2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
,
x
,
y
},
{
then_mod
,
else_mod
});
mm
->
add_return
({
ret
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -1796,7 +1798,8 @@ TEST_CASE(if_pl_test)
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
auto
outline
=
mm
->
add_outline
(
s
);
mm
->
add_return
({
outline
,
ret
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
mm
->
add_return
({
outline
,
r
});
return
p
;
};
...
...
test/verify/test_if_literal.cpp
View file @
4351f46c
...
...
@@ -26,7 +26,8 @@ struct test_if_literal : verify_program<test_if_literal>
else_mod
->
add_return
({
l2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
mm
->
add_return
({
ret
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
mm
->
add_return
({
r
});
return
p
;
}
...
...
test/verify/test_if_lp.cpp
View file @
4351f46c
...
...
@@ -27,7 +27,9 @@ struct test_if_lp : verify_program<test_if_lp>
else_mod
->
add_return
({
s2
,
l2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
mm
->
add_return
({
ret
});
auto
r0
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
auto
r1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
1
}}),
ret
);
mm
->
add_return
({
r0
,
r1
});
return
p
;
}
...
...
test/verify/test_if_param.cpp
View file @
4351f46c
...
...
@@ -29,7 +29,8 @@ struct test_if_param : verify_program<test_if_param>
else_mod
->
add_return
({
a2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
mm
->
add_return
({
ret
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
ret
);
mm
->
add_return
({
r
});
return
p
;
}
...
...
tools/include/operation.hpp
View file @
4351f46c
...
...
@@ -178,7 +178,7 @@ shape normalize_compute_shape_op(const T& x,
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
...
...
@@ -188,14 +188,6 @@ auto compute_op(rank<2>,
return
x
.
compute
(
auto_any_cast
(
ctx
),
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
...
...
@@ -207,50 +199,106 @@ template <class T>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
ctx
,
output_shape
,
input
);
return
compute_op
(
rank
<
1
>
{},
x
,
ctx
,
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
output_shape
,
input
))
{
return
x
.
compute
(
output_shape
,
input
);
}
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
std
::
declval
<
context
&>
()),
output_shape
,
input
))
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable
without a context
: "
+
name
);
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
)
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
{
return
compute_op
(
rank
<
1
>
{},
x
,
output_shape
,
input
);
}
template
<
class
T
,
class
F
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
shape
&
,
const
std
::
vector
<
argument
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
}
template
<
class
T
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
template
<
class
T
,
class
F
>
argument
compute_op
(
const
T
&
x
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
{
return
compute_op
(
rank
<
2
>
{},
x
,
output
_shape
,
input
);
return
compute_op
(
rank
<
1
>
{},
x
,
output
,
input
s
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
auto
compute_op
(
rank
<
1
>
,
auto
compute_op
(
rank
<
3
>
,
const
T
&
x
,
context
&
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
output
,
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
inputs
,
module_args
,
f
);
return
x
.
compute
(
output
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
const
std
::
vector
<
argument
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
auto
compute_op
(
rank
<
2
>
,
const
T
&
x
,
context
&
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
output
,
inputs
))
{
return
x
.
compute
(
output
,
inputs
);
}
template
<
class
T
,
class
F
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
output
,
inputs
);
}
template
<
class
T
,
class
F
>
argument
compute_op
(
rank
<
0
>
,
const
T
&
x
,
context
&
,
const
shape
&
,
const
std
::
vector
<
argument
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
{
std
::
string
name
=
x
.
name
();
MIGRAPHX_THROW
(
"Not computable: "
+
name
);
...
...
@@ -258,11 +306,13 @@ argument
template
<
class
T
,
class
F
>
argument
compute_op
(
const
T
&
x
,
context
&
ctx
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
{
return
compute_op
(
rank
<
1
>
{},
x
,
inputs
,
module_args
,
f
);
return
compute_op
(
rank
<
3
>
{},
x
,
ctx
,
output
,
inputs
,
module_args
,
f
);
}
template
<
class
T
>
...
...
@@ -447,10 +497,22 @@ bool is_borrowed_op(const T&)
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
module_args
=
'
const
std
::
vector
<
module_ref
>&
'
,
run
=
'
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>
'
,
const
=
True
,
default
=
'
detail
::
compute_op
'
),
virtual
(
'
compute
'
,
returns
=
'
argument
'
,
ctx
=
'
context
&
'
,
output
=
'
const
shape
&
'
,
input
=
'
const
std
::
vector
<
argument
>&
'
,
module_args
=
'
const
std
::
vector
<
module_ref
>&
'
,
run
=
'
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
mdl
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
inputs
)
>
'
,
'
std
::
function
<
std
::
vector
<
argument
>
(
module_ref
&
,
const
std
::
unordered_map
<
std
::
string
,
argument
>&
)
>
'
,
const
=
True
,
default
=
'
detail
::
compute_op
'
),
virtual
(
'
to_value
'
,
returns
=
'
value
'
,
const
=
True
,
default
=
'
detail
::
to_value_op
'
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment