Commit e60aff63 authored by Shucai Xiao
Browse files

add a pass for packing int8 op inputs

parent 5dc49b16
......@@ -84,6 +84,7 @@ add_library(migraphx_gpu
lrn.cpp
schedule_model.cpp
adjust_allocation.cpp
pack_int8_args.cpp
clip.cpp
reduce_sum.cpp
reduce_mean.cpp
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// GPU compiler pass that packs the inputs of int8 operators.
// The struct holds no data members; all work happens in apply().
struct pack_int8_args
{
// Name of this pass.
std::string name() const { return "gpu::pack_int8_args"; }
// Rewrites program p in place. Only declared here; the definition lives in
// the accompanying source file.
void apply(program& p) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/// Visit every instruction of the program and, for each "gpu::quant_gemm",
/// route the operands that require it through an int8 packing instruction
/// whose output buffer is a new hip_allocate of the operand's shape.
///
/// - The second operand goes through hip_int8_gemm_pack_a when its shape is
///   NOT transposed.
/// - The first operand goes through hip_int8_gemm_pack_b when its shape IS
///   transposed.
///
/// @param p program that is modified in place.
void pack_int8_args::apply(program& p) const
{
    // Shared sequence for one operand: insert (at `consumer`) a buffer shaped
    // like `arg`, insert `pack_op` taking {arg, buffer}, then make `consumer`
    // use the pack result in place of `arg`.
    auto repack = [&p](auto consumer, auto arg, auto pack_op) {
        auto buffer = p.insert_instruction(consumer, hip_allocate{arg->get_shape()});
        auto packed = p.insert_instruction(consumer, pack_op, {arg, buffer});
        instruction::replace_argument(consumer, arg, packed);
    };

    for(auto ins : iterator_for(p))
    {
        // Only quantized GEMMs are rewritten; every other instruction,
        // "gpu::quant_convolution" included, is left exactly as it is.
        if(ins->name() != "gpu::quant_gemm")
            continue;

        // Deliberately a copy: the argument list of `ins` is rewritten below,
        // while these entries must keep referring to the original operands.
        auto args = ins->inputs();

        // Both layout flags are read before any rewriting takes place.
        const bool a_is_transposed = args[0]->get_shape().transposed();
        const bool b_is_transposed = args[1]->get_shape().transposed();

        // Second operand first, then the first operand (order preserved).
        if(!b_is_transposed)
            repack(ins, args[1], hip_int8_gemm_pack_a{});

        if(a_is_transposed)
            repack(ins, args[0], hip_int8_gemm_pack_b{});
    }
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -21,6 +21,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
......@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
adjust_allocation{},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment