Commit 99626b4c authored by Alan Turner

Enable int8 gemm-pointwise fusion

parent 3d0426e9
@@ -73,7 +73,7 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
- return {shape::int32_type, out_lens};
+ return {shape::int8_type, out_lens};
}
};
@@ -112,28 +112,28 @@ void quantize_int8(program& prog,
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
- auto capture_prog = prog;
- capture_prog.compile(t);
+ // auto capture_prog = prog;
+ // capture_prog.compile(t);
- // use all calibration data to run the program to calculate the
- // quantization scale and shift
- for(auto&& arg : calibration)
- {
- parameter_map m;
- for(auto&& x : capture_prog.get_parameter_shapes())
- {
- if(arg.count(x.first) > 0)
- {
- assert(x.second == arg.at(x.first).get_shape());
- m[x.first] = t.copy_to(arg.at(x.first));
- }
- else
- {
- m[x.first] = t.allocate(x.second);
- }
- }
- capture_prog.eval(m);
- }
+ // // use all calibration data to run the program to calculate the
+ // // quantization scale and shift
+ // for(auto&& arg : calibration)
+ // {
+ // parameter_map m;
+ // for(auto&& x : capture_prog.get_parameter_shapes())
+ // {
+ // if(arg.count(x.first) > 0)
+ // {
+ // assert(x.second == arg.at(x.first).get_shape());
+ // m[x.first] = t.copy_to(arg.at(x.first));
+ // }
+ // else
+ // {
+ // m[x.first] = t.allocate(x.second);
+ // }
+ // }
+ // capture_prog.eval(m);
+ // }
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
@@ -37,10 +37,11 @@ void apply_quantizelinear(module& m, instruction_ref ins)
assert(ins->name() == "quantizelinear");
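// QuantizeLinear computes y = saturate(round(x / y_scale) + y_zero_point); this pass expands it into primitive ops.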
auto x = ins->inputs()[0];
auto y_scale = ins->inputs()[1];
+ auto scale_type = y_scale->get_shape().type();
if(x->get_shape().type() != y_scale->get_shape().type())
{
- x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
+ x = m.insert_instruction(ins, make_op("convert", {{"target_type", scale_type}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
@@ -48,7 +49,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
@@ -72,14 +73,15 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
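// DequantizeLinear computes y = (x - x_zero_point) * x_scale.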
- auto x = m.insert_instruction(
- ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1];
+ auto scale_type = x_scale->get_shape().type();
+ auto x = m.insert_instruction(
+ ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[0]);
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
@@ -100,6 +102,8 @@ void rewrite_quantization::apply(module& m) const
apply_dequantizelinear(m, ins);
}
}
// std::cout << "after rwq: " << std::endl;
// m.debug_print();
}
} // namespace MIGRAPHX_INLINE_NS
@@ -4,7 +4,6 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -52,16 +51,53 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
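+ // int8 counterpart of ck_gemm: wraps a quant_dot and, when a pointwise
+ // module is fused, takes its output type from that module.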
+ struct ck_gemm_int8
+ {
+ operation op = make_op("quant_dot");
+ template <class Self, class F>
+ static auto reflect(Self& self, F f)
+ {
+ return pack(f(self.op, "op"));
+ }
+ std::string name() const { return "gpu::ck_gemm_int8"; }
+ void check_gemm_shape(const shape& s) const
+ {
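+ // CK gemm kernels require a unit stride among the last three dimensions.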
+ if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+ MIGRAPHX_THROW("Invalid shape for ck_gemm_int8");
+ }
+ shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+ {
+ check_shapes{inputs, *this}.same_ndims();
+ // if(mods.size() != 1)
+ // MIGRAPHX_THROW("should have one submodule.");
+ if(inputs.size() < 2)
+ MIGRAPHX_THROW("should have at least two inputs.");
+ auto a = inputs[0];
+ auto b = inputs[1];
+ for(const auto& input : inputs)
+ check_gemm_shape(input);
+ auto r = op.compute_shape({a, b});
+ if(mods.empty())
+ return r.with_type(migraphx::shape::int8_type);
+ return r.with_type(mods.front()->get_output_shapes().front().type());
+ }
+ };
+ MIGRAPHX_REGISTER_OP(ck_gemm_int8);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
- if(a.lens().back() > 2048)
- return false;
+ // if(a.lens().back() > 2048)
+ // return false;
return true;
}
@@ -93,9 +129,9 @@ struct find_ck_gemm_pointwise
{
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
- auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
+ auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
auto new_first_param =
- pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
+ pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
@@ -108,6 +144,91 @@ struct find_ck_gemm_pointwise
}
};
+ struct find_ck_gemm_pointwise_int8
+ {
+ // Find a gemm followed by a pointwise operation.
+ auto matcher() const
+ {
+ auto gemm =
+ match::skip(match::name("contiguous"))(match::name("quant_dot")(is_ck_gemm().bind("gemm")));
+ return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
+ }
+ void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+ {
+ auto ins = r.result;
+ auto gemm_ins = r.instructions["gemm"];
+ auto x_ins = r.instructions["x"]; // input after contiguous
+ auto next_ins = std::next(ins);
+ // if (next_ins->name() == "quant_dot")
+ // {
+ // std::cout << "\nins: ";
+ // ins->debug_print();
+ // std::cout << "\ngemm_ins: ";
+ // gemm_ins->debug_print();
+ // std::cout << "\nx_ins: ";
+ // x_ins->debug_print();
+ // std::cout << "\nnext: ";
+ // next_ins->debug_print();
+ // mpm.get_module().debug_print();
+ // }
+ auto* pm = ins->module_inputs().front();
+ auto names = pm->get_parameter_names();
+ std::sort(names.begin(), names.end());
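+ // Parameter names sort into positional order (typically x0, x1, ...), so names[0] is the pointwise op's first input.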
+ auto inputs = ins->inputs();
+ auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
+ auto gemm_idx = gemm_it - inputs.begin();
+ assert(gemm_it != inputs.end());
+ // if(ins->get_shape().type() != shape::half_type)
+ // return;
+ // if (next_ins->name() == "reshape")
+ // {
+ // std::cout << "PM before: " << std::endl;
+ // pm->debug_print();
+ // }
+ if(gemm_idx != 0)
+ {
+ auto first_param = pm->get_parameter(names[0]);
+ auto gemm_param = pm->get_parameter(names[gemm_idx]);
+ auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
+ auto new_first_param =
+ pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
+ pm->replace_instruction(gemm_param, new_gemm_param);
+ pm->replace_instruction(first_param, new_first_param);
+ pm->remove_instruction(first_param);
+ pm->remove_instruction(gemm_param);
+ }
+ // if (next_ins->name() == "reshape")
+ // {
+ // std::cout << "PM after: " << std::endl;
+ // pm->debug_print();
+ // }
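+ // Put the gemm's own operands first, replacing the pointwise input that was fed by the gemm.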
+ inputs.erase(gemm_it);
+ inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
+ // std::cout << "Next_ins inputs: " << std::endl;
+ // for (auto& in : next_ins->inputs())
+ // {
+ // in->debug_print();
+ // }
+ // auto out_shape = compute_shape(ck_gemm_int8{}, inputs, {pm});
+ // instruction::replace(ins, ck_gemm_int8{}, out_shape.with_type(migraphx::shape::half_type), inputs, {pm});
+ mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
+ // std::cout << "Next_ins inputs (post replace): " << std::endl;
+ // for (auto& in : std::next(ins)->inputs())
+ // {
+ // in->debug_print();
+ // }
+ // if (next_ins->name() == "softmax" or next_ins->name() == "reshape")
+ // {
+ // std::cout << "After replace: " << std::endl;
+ // mpm.get_module().debug_print();
+ // }
+ }
+ };
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
@@ -119,14 +240,31 @@ struct find_ck_gemm
}
};
+ struct find_ck_gemm_int8
+ {
+ auto matcher() const { return match::name("quant_dot")(is_ck_gemm().bind("gemm")); }
+ void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+ {
+ auto ins = r.result;
+ mpm.get_module().replace_instruction(ins, ck_gemm_int8{ins->get_operator()}, ins->inputs());
+ }
+ };
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
+ match::find_matches(mpm, find_ck_gemm_pointwise_int8{});
}
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
{
match::find_matches(mpm, find_ck_gemm{});
+ match::find_matches(mpm, find_ck_gemm_int8{});
}
}
} // namespace gpu
@@ -122,6 +122,12 @@ struct instance
params[8] = s;
}
+ void set_e_type(const std::string& s)
+ {
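+ // params[9] holds the E (output) data type; it is matched as x[9] in the tuning lookup below.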
+ // assert(params[9] == "ck::Tuple<>");
+ params[9] = s;
+ }
void set_ds_op(const std::string& s)
{
assert(params[12] == "ck_passthrough");
@@ -134,6 +140,23 @@ struct instance
params[13] = s;
}
+ void set_a_scalar_per_vec(const std::string& s)
+ {
+ params[block_size_index + 14] = s;
+ params[block_size_index + 15] = s;
+ }
+ void set_b_scalar_per_vec(const std::string& s)
+ {
+ params[block_size_index + 20] = s;
+ params[block_size_index + 21] = s;
+ }
+ void set_c_scalar_per_vec(const std::string& s)
+ {
+ params[params.size() - 3] = s;
+ }
std::string str() const { return join_strings(params, ","); }
};
@@ -175,12 +198,20 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
if(tuning.empty())
std::cout << "*********** Warning: No CK tuning!" << std::endl;
{
std::cout << "*********** Warning: No CK tuning! for config:" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
}
auto it = std::find_if(
tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
if(it == tuning.end())
{
std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
std::cout << " " << inputs[0] << std::endl;
std::cout << " " << inputs[1] << std::endl;
std::cout << " " << inputs[2] << std::endl;
std::vector<std::pair<float, std::size_t>> w;
std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
if(inputs.size() < 3 or p.first.size() < 3)
@@ -274,7 +305,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
s = shape{s.type(), {m1, m2}};
}
- std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
+ std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm", "ck_gemm_int8", "gpu::ck_gemm_int8"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{
@@ -293,11 +324,27 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto k = a_shape.lens().back();
std::array<char, 3> keys{'M', 'N', 'K'};
std::array<std::size_t, 3> config{m, n, k};
- auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
+ auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape.with_type(a_shape.type())}));
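+ // Look the config up with the output cast to the input type, matching the shapes that MIGRAPHX_LOG_CK_GEMM records below.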
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
+ // if (not (get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+ // get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+ // get_type(b_shape) == x[5] and get_type(c_shape) == x[9]))
+ // {
+ // std::cout << get_layout(a_shape) << " - " << x[0] <<std::endl;
+ // std::cout << get_layout(b_shape) << " - " << x[1] <<std::endl;
+ // std::cout << get_layout(c_shape) << " - " << x[3] <<std::endl;
+ // std::cout << get_type(a_shape) << " - " << x[4] <<std::endl;
+ // std::cout << get_type(b_shape) << " - " << x[5] <<std::endl;
+ // std::cout << get_type(c_shape) << " - " << x[9] <<std::endl;
+ // }
+ /* return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+ get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+ get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
- get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
+ get_type(b_shape) == x[5];
})};
assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post"))
@@ -305,7 +352,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
ip.set_ds_op(v.at("post").to<std::string>());
}
+ ip.set_e_type(get_type(c_shape));
+ if(std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
+ {
+ ip.set_c_scalar_per_vec("8");
+ }
+ if(std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
+ {
+ ip.set_c_scalar_per_vec("4");
+ }
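+ // 8 halves or 4 floats keep the vectorized output stores at 16 bytes per access.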
auto padding = ip.get_pad(config);
std::string gemm_type;
@@ -349,7 +407,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
+ // std::cout << options.kernel_name << ": " << std::endl;
return compile_hip_code_object(src, options);
}
@@ -370,7 +428,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
{
- std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
+ std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
}
});
@@ -30,21 +30,26 @@ def parse_args():
type=str,
help=
'Existing tuning JSON. Configs already present will not be re-tuned.')
parser.add_argument("-q", "--quantize_int8", action="store_true")
args = parser.parse_args()
return args
- def tune_models(models, batch_sizes, seq_len, n, existing):
+ def tune_models(models, batch_sizes, seq_len, n, existing, q_int8):
time_stamp = time.strftime("%Y_%m_%d_%H_%M")
log_file = "ck_tuning_{}.log".format(time_stamp)
json_file = "ck_tuning_{}.json".format(time_stamp)
+ prec_str = "--int8" if q_int8 else ""
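+ # --int8 makes the driver quantize the model so int8 gemm configs show up in the CK log.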
for model in models:
for batch in batch_sizes:
params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 ".format(
batch, batch)
params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 {} ".format(
batch, batch, prec_str)
if "bert" in model:
params = "--fill1 input_ids --input-dim @input_ids {} {} ".format(
batch, seq_len)
params = "{} --fp16 --fill1 input_ids --input-dim @input_ids {} {} ".format(
prec_str, batch, seq_len)
if "squad" in model:
params = "--fill1 input_ids:0 unique_ids_raw_output___9:0 input_mask:0 segment_ids:0 --input-dim @input_ids:0 {} 256 @input_mask:0 {} 256 @segment_ids:0 {} 256 --fp16 {}".format(
batch, batch, batch, prec_str)
out = subprocess.run(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
.format(model, params, log_file),
@@ -96,7 +101,7 @@ def tune_models(models, batch_sizes, seq_len, n, existing):
def run(args):
tune_models(args.models, args.batch_sizes, args.sequence_length, args.n,
- args.update)
+ args.update, args.quantize_int8)
run(parse_args())