Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f1c18355
Commit
f1c18355
authored
Sep 26, 2022
by
charlie
Browse files
Fixed using pack() correctly
parent
b76a9043
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
125 additions
and
65 deletions
+125
-65
src/include/migraphx/operation.hpp
src/include/migraphx/operation.hpp
+76
-30
tools/include/operation.hpp
tools/include/operation.hpp
+49
-35
No files found.
src/include/migraphx/operation.hpp
View file @
f1c18355
...
@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -47,6 +47,45 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
struct
context
;
/*
template <class F>
struct dyn_output
{
F ins_inputs;
dyn_output(F f) : ins_inputs(f){};
shape get_input_shape()
{
if(ins_shape.empty())
{
ins_shape = unpack(
[&](const auto&, shape s, const std::vector<argument>&) { return s; }, ins_inputs);
}
return ins_shape;
}
shape get_output_shape()
{
if(computed_shape.empty())
{
computed_shape = unpack(
[&](const auto& x, shape, const std::vector<argument>& inputs) {
return compute_shape(x, to_shapes(inputs));
},
ins_inputs);
}
return computed_shape;
}
private:
// original shape from the instruction
shape ins_shape;
// shape computed at eval time using input arguments
shape computed_shape;
};
*/
struct
dyn_output
struct
dyn_output
{
{
// original shape from the instruction
// original shape from the instruction
...
@@ -67,26 +106,22 @@ struct compute_output_shape
...
@@ -67,26 +106,22 @@ struct compute_output_shape
operator
dyn_output
()
const
operator
dyn_output
()
const
{
{
return
unpack
(
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
[](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
});
},
ins_inputs
);
}
}
operator
shape
()
const
operator
shape
()
const
{
{
return
unpack
(
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
ins_inputs
);
}
}
};
};
template
<
class
T
>
template
<
class
F
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
{
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)
};
return
{
f
};
}
}
#ifdef DOXYGEN
#ifdef DOXYGEN
...
@@ -243,10 +278,11 @@ auto compute_op(rank<1>,
...
@@ -243,10 +278,11 @@ auto compute_op(rank<1>,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
input
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -265,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -290,9 +326,10 @@ auto compute_op(rank<1>,
...
@@ -290,9 +326,10 @@ auto compute_op(rank<1>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -324,12 +361,17 @@ auto compute_op(rank<4>,
...
@@ -324,12 +361,17 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
return
x
.
compute
(
auto_any_cast
(
ctx
),
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -340,9 +382,10 @@ auto compute_op(rank<3>,
...
@@ -340,9 +382,10 @@ auto compute_op(rank<3>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -352,9 +395,10 @@ auto compute_op(rank<2>,
...
@@ -352,9 +395,10 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -364,10 +408,12 @@ auto compute_op(rank<1>,
...
@@ -364,10 +408,12 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
...
tools/include/operation.hpp
View file @
f1c18355
...
@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -47,6 +47,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
context
;
struct
context
;
/*
template <class F>
struct dyn_output
struct dyn_output
{
{
F ins_inputs;
F ins_inputs;
...
@@ -82,6 +84,15 @@ struct dyn_output
...
@@ -82,6 +84,15 @@ struct dyn_output
// shape computed at eval time using input arguments
// shape computed at eval time using input arguments
shape computed_shape;
shape computed_shape;
};
};
*/
struct
dyn_output
{
// original shape from the instruction
shape
ins_shape
;
// shape computed at eval time using input arguments
shape
computed_shape
;
};
/**
/**
* Handle dynamic and static shape at evaluation time.
* Handle dynamic and static shape at evaluation time.
...
@@ -93,32 +104,24 @@ struct compute_output_shape
...
@@ -93,32 +104,24 @@ struct compute_output_shape
{
{
F
ins_inputs
;
F
ins_inputs
;
operator
dyn_output
<
F
>
()
const
operator
dyn_output
()
const
{
{
/*
return
ins_inputs
([](
const
auto
&
x
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
inputs
)
{
return unpack([](const auto& x, shape ins_shape, const std::vector<argument>& inputs)
return
dyn_output
{
ins_shape
,
compute_shape
(
x
,
to_shapes
(
inputs
))};
{
});
return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
},
ins_inputs
);
*/
return
dyn_output
<
F
>
{
ins_inputs
};
}
}
operator
shape
()
const
operator
shape
()
const
{
{
return
unpack
(
return
ins_inputs
(
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
},
[](
const
auto
&
,
shape
ins_shape
,
const
std
::
vector
<
argument
>&
)
{
return
ins_shape
;
});
ins_inputs
);
}
}
};
};
template
<
class
T
>
template
<
class
F
>
auto
make_compute_output_shape
(
const
T
&
x
,
shape
output_shape
,
const
std
::
vector
<
argument
>&
inputs
)
compute_output_shape
<
F
>
make_compute_output_shape
(
F
f
)
->
decltype
(
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)})
{
{
return
compute_output_shape
{
pack
(
x
,
output_shape
,
inputs
)
};
return
{
f
};
}
}
#ifdef DOXYGEN
#ifdef DOXYGEN
...
@@ -275,10 +278,11 @@ auto compute_op(rank<1>,
...
@@ -275,10 +278,11 @@ auto compute_op(rank<1>,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
input
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)),
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
...
@@ -297,9 +301,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
template
<
class
T
>
template
<
class
T
>
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
auto
compute_op
(
rank
<
1
>
,
const
T
&
x
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
input
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output_shape
,
input
),
input
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output_shape
,
input
)
)
,
input
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -322,9 +326,10 @@ auto compute_op(rank<1>,
...
@@ -322,9 +326,10 @@ auto compute_op(rank<1>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -356,12 +361,17 @@ auto compute_op(rank<4>,
...
@@ -356,12 +361,17 @@ auto compute_op(rank<4>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
return
x
.
compute
(
auto_any_cast
(
ctx
),
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -372,9 +382,10 @@ auto compute_op(rank<3>,
...
@@ -372,9 +382,10 @@ auto compute_op(rank<3>,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
module_args
,
const
std
::
vector
<
module_ref
>&
module_args
,
F
f
)
F
f
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
))
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
,
module_args
,
f
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
,
module_args
,
f
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
,
module_args
,
f
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -384,9 +395,10 @@ auto compute_op(rank<2>,
...
@@ -384,9 +395,10 @@ auto compute_op(rank<2>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
))
F
)
->
decltype
(
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)
)
,
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
@@ -396,10 +408,12 @@ auto compute_op(rank<1>,
...
@@ -396,10 +408,12 @@ auto compute_op(rank<1>,
const
shape
&
output
,
const
shape
&
output
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
module_ref
>&
,
const
std
::
vector
<
module_ref
>&
,
F
)
F
)
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
->
decltype
(
x
.
compute
(
auto_any_cast
(
ctx
),
compute_output_shape
<
T
>
{
x
,
output
,
inputs
},
inputs
))
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
))
{
{
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
x
,
output
,
inputs
),
inputs
);
return
x
.
compute
(
auto_any_cast
(
ctx
),
make_compute_output_shape
(
pack
(
x
,
output
,
inputs
)),
inputs
);
}
}
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment