Unverified Commit 8b4c69c5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Assert the shape for compute and compute_shape are the same (#936)



Assert shapes don't change
Co-authored-by: default avatarShucai Xiao <Shucai.Xiao@amd.com>
parent 6b6e9362
...@@ -193,6 +193,7 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -193,6 +193,7 @@ std::vector<argument> generic_eval(const module* mod,
auto trace = make_trace(mod); auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod)) for(auto ins : iterator_for(*mod))
{ {
assert(results.find(ins) == results.end());
const auto& name = ins->name(); const auto& name = ins->name();
if(name == "@literal") if(name == "@literal")
{ {
...@@ -250,6 +251,7 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -250,6 +251,7 @@ std::vector<argument> generic_eval(const module* mod,
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape());
} }
return {results.at(std::prev(mod->end()))}; return {results.at(std::prev(mod->end()))};
} }
...@@ -522,8 +524,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -522,8 +524,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
// Fill the map // Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) { generic_eval(*this, ctx, params, always([&](auto ins, auto) {
ins_vec[ins].reserve(n); ins_vec[ins].reserve(n);
return argument{}; return argument{ins->get_shape(), nullptr};
})); }));
// Run and time each instruction // Run and time each instruction
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
...@@ -663,7 +666,9 @@ void program::print_cpp(std::ostream& os) const ...@@ -663,7 +666,9 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const void program::dry_run(std::unordered_map<std::string, argument> params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), always([](auto&&...) { return argument{}; })); generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
} }
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......
File mode changed from 100644 to 100755
...@@ -5,10 +5,25 @@ namespace migraphx { ...@@ -5,10 +5,25 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// Returns a copy of an int8 shape whose channel dimension (lens[1]) is
// padded up to the next multiple of 4, with the batch stride (strides[0])
// recomputed to match the padded channel count. Strides of the remaining
// dimensions are left untouched. Throws for any non-int8 input shape.
// NOTE(review): assumes the shape has at least 2 dimensions (NCHW-style
// conv input) — confirm against callers.
shape pack_int8_shape(const shape& s)
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
    }
    auto padded_lens    = s.lens();
    auto padded_strides = s.strides();
    // Round the channel count up to a multiple of 4.
    padded_lens[1]    = (padded_lens[1] + 3) / 4 * 4;
    padded_strides[0] = padded_strides[1] * padded_lens[1];
    return {s.type(), padded_lens, padded_strides};
}
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{{inputs.at(0)}, *this}.has(1).standard(); check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return inputs.at(0); return pack_int8_shape(inputs.at(0));
} }
argument argument
......
File mode changed from 100644 to 100755
...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}}, eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
pack_int8_args{}, pack_int8_args{},
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment