Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f1c18355
Commit
f1c18355
authored
Sep 26, 2022
by
charlie
Browse files
Fixed using pack() correctly
parent
b76a9043
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
125 additions
and
65 deletions
+125
-65
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+76
-30
tools/include/operation.hpp
tools/include/operation.hpp
+49
-35
No files found.
src/include/migraphx/operation.hpp
View file @
f1c18355
...
...
@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct
dyn_output
{
// original shape from the instruction
...
...
@@ -67,26 +106,22 @@ struct compute_output_shape
operator
dyn_output
()
const
{
return
unpack
(
[](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
},
ins_inputs
);
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
});
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
template
<
class
F
>
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)
};
return
{
f
};
}
#ifdef DOXYGEN
...
...
@@ -243,10 +278,11 @@ auto compute_op(rank<1>,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
);
}
template
<
class
T
>
...
...
@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
);
}
template
<
class
T
>
...
...
@@ -290,9 +326,10 @@ auto compute_op(rank<1>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -324,12 +361,17 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -340,9 +382,10 @@ auto compute_op(rank<3>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -352,9 +395,10 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
);
}
template
<
class
T
,
class
F
>
...
...
@@ -364,10 +408,12 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
);
}
template
<
class
T
,
class
F
>
...
...
tools/include/operation.hpp
View file @
f1c18355
...
...
@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
...
...
@@ -82,6 +84,15 @@ struct dyn_output
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct
dyn_output
{
// original shape from the instruction
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
* Handle dynamic and static shape at evaluation time.
...
...
@@ -93,32 +104,24 @@ struct compute_output_shape
{
F
ins_inputs
;
operator
dyn_output
<
F
>
()
const
operator
dyn_output
()
const
{
/*
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
{
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return
dyn_output
<
F
>
{
ins_inputs
};
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
});
}
operator
shape
()
const
{
return
unpack
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
ins_inputs
);
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
}
};
template
<
class
T
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
template
<
class
F
>
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)
};
return
{
f
};
}
#ifdef DOXYGEN
...
...
@@ -275,10 +278,11 @@ auto compute_op(rank<1>,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
);
}
template
<
class
T
>
...
...
@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
);
}
template
<
class
T
>
...
...
@@ -322,9 +326,10 @@ auto compute_op(rank<1>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -356,12 +361,17 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -372,9 +382,10 @@ auto compute_op(rank<3>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
template
<
class
T
,
class
F
>
...
...
@@ -384,9 +395,10 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
);
}
template
<
class
T
,
class
F
>
...
...
@@ -396,10 +408,12 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
compute_output_shape
<
T
>
{
x
,
output
,
inputs
},
inputs
))
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
);
}
template
<
class
T
,
class
F
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment