Unverified Commit 8b4c69c5 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Assert the shape for compute and compute_shape are the same (#936)



Assert shapes dont change
Co-authored-by: default avatarShucai Xiao <Shucai.Xiao@amd.com>
parent 6b6e9362
......@@ -193,6 +193,7 @@ std::vector<argument> generic_eval(const module* mod,
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod))
{
assert(results.find(ins) == results.end());
const auto& name = ins->name();
if(name == "@literal")
{
......@@ -250,6 +251,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape());
}
return {results.at(std::prev(mod->end()))};
}
......@@ -522,8 +524,9 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
// Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{};
return argument{ins->get_shape(), nullptr};
}));
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
......@@ -663,7 +666,9 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), always([](auto&&...) { return argument{}; }));
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......
File mode changed from 100644 to 100755
......@@ -5,10 +5,25 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape pack_int8_shape(const shape& s)
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return inputs.at(0);
return pack_int8_shape(inputs.at(0));
}
argument
......
File mode changed from 100644 to 100755
......@@ -93,10 +93,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math},
dead_code_elimination{},
write_literals{&ctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment