Commit 13a8bcaa authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_check_shapes' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv

parents d5636acd d6afa9e9
...@@ -38,22 +38,34 @@ struct check_shapes ...@@ -38,22 +38,34 @@ struct check_shapes
const shape* begin; const shape* begin;
const shape* end; const shape* end;
const std::string name; const std::string name;
const bool dynamic_allowed;
check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n) check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
: begin(b), end(e), name(n), dynamic_allowed(d)
{ {
} }
template <class Op> template <class Op>
check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name()) check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
: begin(b), end(e), name(op.name()), dynamic_allowed(d)
{ {
} }
template <class Op> template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.data()), end(s.data() + s.size()), name(op.name()) : begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
{ {
} }
~check_shapes()
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
{
std::cerr << prefix() << "Dynamic shapes not supported" << std::endl;
std::abort();
}
}
std::string prefix() const std::string prefix() const
{ {
if(name.empty()) if(name.empty())
...@@ -92,6 +104,11 @@ struct check_shapes ...@@ -92,6 +104,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has exactly n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& only_dims(std::size_t n) const const check_shapes& only_dims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -104,6 +121,11 @@ struct check_shapes ...@@ -104,6 +121,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has a maximum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& max_ndims(std::size_t n) const const check_shapes& max_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -117,6 +139,11 @@ struct check_shapes ...@@ -117,6 +139,11 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check that the first shape has a minimum of n dimensions.
* Do nothing if the container is empty.
* \param n number of dimensions
*/
const check_shapes& min_ndims(std::size_t n) const const check_shapes& min_ndims(std::size_t n) const
{ {
assert(begin != nullptr); assert(begin != nullptr);
...@@ -130,6 +157,9 @@ struct check_shapes ...@@ -130,6 +157,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same shape.
*/
const check_shapes& same_shape() const const check_shapes& same_shape() const
{ {
if(!this->same([](const shape& s) { return s; })) if(!this->same([](const shape& s) { return s; }))
...@@ -137,6 +167,9 @@ struct check_shapes ...@@ -137,6 +167,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same type.
*/
const check_shapes& same_type() const const check_shapes& same_type() const
{ {
if(!this->same([](const shape& s) { return s.type(); })) if(!this->same([](const shape& s) { return s.type(); }))
...@@ -144,6 +177,9 @@ struct check_shapes ...@@ -144,6 +177,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same lens.
*/
const check_shapes& same_dims() const const check_shapes& same_dims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens(); })) if(!this->same([](const shape& s) { return s.max_lens(); }))
...@@ -151,6 +187,9 @@ struct check_shapes ...@@ -151,6 +187,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same number of dimensions.
*/
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(!this->same([](const shape& s) { return s.max_lens().size(); })) if(!this->same([](const shape& s) { return s.max_lens().size(); }))
...@@ -158,6 +197,9 @@ struct check_shapes ...@@ -158,6 +197,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are standard.
*/
const check_shapes& standard() const const check_shapes& standard() const
{ {
if(!this->all_of([](const shape& s) { return s.standard(); })) if(!this->all_of([](const shape& s) { return s.standard(); }))
...@@ -165,6 +207,9 @@ struct check_shapes ...@@ -165,6 +207,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are standard or scalar.
*/
const check_shapes& standard_or_scalar() const const check_shapes& standard_or_scalar() const
{ {
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); })) if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
...@@ -172,6 +217,9 @@ struct check_shapes ...@@ -172,6 +217,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed.
*/
const check_shapes& packed() const const check_shapes& packed() const
{ {
if(!this->all_of([](const shape& s) { return s.packed(); })) if(!this->all_of([](const shape& s) { return s.packed(); }))
...@@ -179,6 +227,9 @@ struct check_shapes ...@@ -179,6 +227,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are packed or broadcasted.
*/
const check_shapes& packed_or_broadcasted() const const check_shapes& packed_or_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); })) if(!this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
...@@ -186,6 +237,9 @@ struct check_shapes ...@@ -186,6 +237,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are tuples.
*/
const check_shapes& tuple_type() const const check_shapes& tuple_type() const
{ {
if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; })) if(!this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
...@@ -193,6 +247,9 @@ struct check_shapes ...@@ -193,6 +247,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not transposed.
*/
const check_shapes& not_transposed() const const check_shapes& not_transposed() const
{ {
if(!this->all_of([](const shape& s) { return not s.transposed(); })) if(!this->all_of([](const shape& s) { return not s.transposed(); }))
...@@ -200,6 +257,9 @@ struct check_shapes ...@@ -200,6 +257,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are not broadcasted.
*/
const check_shapes& not_broadcasted() const const check_shapes& not_broadcasted() const
{ {
if(!this->all_of([](const shape& s) { return not s.broadcasted(); })) if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
...@@ -207,6 +267,10 @@ struct check_shapes ...@@ -207,6 +267,10 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same n elements.
* \param n number of elements
*/
const check_shapes& elements(std::size_t n) const const check_shapes& elements(std::size_t n) const
{ {
if(!this->all_of([&](const shape& s) { return s.elements() == n; })) if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
...@@ -214,6 +278,9 @@ struct check_shapes ...@@ -214,6 +278,9 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check the batches of all the shapes do not have transposed strides.
*/
const check_shapes& batch_not_transposed() const const check_shapes& batch_not_transposed() const
{ {
if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); })) if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
...@@ -242,6 +309,16 @@ struct check_shapes ...@@ -242,6 +309,16 @@ struct check_shapes
return std::all_of(begin, end, p); return std::all_of(begin, end, p);
} }
template <class Predicate>
bool any_of(Predicate p) const
{
if(begin == end)
return false;
assert(begin != nullptr);
assert(end != nullptr);
return std::any_of(begin, end, p);
}
const shape* get(long i) const const shape* get(long i) const
{ {
if(i >= size()) if(i >= size())
......
...@@ -740,11 +740,13 @@ void program::perf_report(std::ostream& os, ...@@ -740,11 +740,13 @@ void program::perf_report(std::ostream& os,
double overhead_percent = overhead_time * 100.0 / total_time; double overhead_percent = overhead_time * 100.0 / total_time;
double total_instruction_time = 0.0; double total_instruction_time = 0.0;
std::unordered_map<std::string, double> op_times; std::unordered_map<std::string, double> op_times;
std::unordered_map<std::string, std::size_t> op_n;
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
{ {
double avg = common_average(p.second); double avg = common_average(p.second);
op_times[perf_group(p.first->get_operator())] += avg; op_times[perf_group(p.first->get_operator())] += avg;
total_instruction_time += avg; total_instruction_time += avg;
op_n[perf_group(p.first->get_operator())]++;
} }
double calculate_overhead_time = total_time - total_instruction_time; double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time; double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
...@@ -765,18 +767,19 @@ void program::perf_report(std::ostream& os, ...@@ -765,18 +767,19 @@ void program::perf_report(std::ostream& os,
os << std::endl; os << std::endl;
os << "Summary:" << std::endl; os << "Summary:" << std::endl;
std::vector<std::pair<double, std::string>> op_times_sorted; std::vector<std::tuple<double, std::size_t, std::string>> op_times_sorted;
std::transform(op_times.begin(), std::transform(
op_times.end(), op_times.begin(), op_times.end(), std::back_inserter(op_times_sorted), [&](auto p) {
std::back_inserter(op_times_sorted), auto&& name = p.first;
[](auto p) { return std::make_pair(p.second, p.first); }); return std::make_tuple(p.second, op_n.at(name), name);
});
std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{}); std::sort(op_times_sorted.begin(), op_times_sorted.end(), std::greater<>{});
for(auto&& p : op_times_sorted) for(auto&& [avg, nn, name] : op_times_sorted)
{ {
auto&& name = p.second;
double avg = p.first;
double percent = std::ceil(100.0 * avg / total_instruction_time); double percent = std::ceil(100.0 * avg / total_instruction_time);
os << name << ": " << avg << "ms, " << percent << "%" << std::endl; double per_ins = avg / nn;
os << name << ": " << avg << "ms / " << nn << " = " << per_ins << "ms, " << percent << "%"
<< std::endl;
} }
os << std::endl; os << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment