Commit 5cef60b8 authored by Alan Turner
Browse files

Cleanup

parent 2f2757ac
......@@ -112,28 +112,28 @@ void quantize_int8(program& prog,
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
// auto capture_prog = prog;
// capture_prog.compile(t);
auto capture_prog = prog;
capture_prog.compile(t);
// // use all calibration data to run the program to calculate the
// // quantization scale and shift
// for(auto&& arg : calibration)
// {
// parameter_map m;
// for(auto&& x : capture_prog.get_parameter_shapes())
// {
// if(arg.count(x.first) > 0)
// {
// assert(x.second == arg.at(x.first).get_shape());
// m[x.first] = t.copy_to(arg.at(x.first));
// }
// else
// {
// m[x.first] = t.allocate(x.second);
// }
// }
// capture_prog.eval(m);
// }
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : capture_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
capture_prog.eval(m);
}
// print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
......
......@@ -96,8 +96,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// if(a.lens().back() > 2048)
// return false;
if(a.lens().back() > 2048)
return false;
return true;
}
......@@ -160,20 +160,6 @@ struct find_ck_gemm_pointwise_int8
auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto next_ins = std::next(ins);
// if (next_ins->name() == "quant_dot")
// {
// std::cout << "\nins: ";
// ins->debug_print();
// std::cout << "\ngemm_ins: ";
// gemm_ins->debug_print();
// std::cout << "\nx_ins: ";
// x_ins->debug_print();
// std::cout << "\nnext: ";
// next_ins->debug_print();
// mpm.get_module().debug_print();
// }
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
......@@ -182,13 +168,6 @@ struct find_ck_gemm_pointwise_int8
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
// if(ins->get_shape().type() != shape::half_type)
// return;
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM before: " << std::endl;
// pm->debug_print();
// }
if(gemm_idx != 0)
{
auto first_param = pm->get_parameter(names[0]);
......@@ -201,31 +180,9 @@ struct find_ck_gemm_pointwise_int8
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM after: " << std::endl;
// pm->debug_print();
// }
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
// std::cout << "Next_ins inputs: " << std::endl;
// for (auto& in : next_ins->inputs())
// {
// in->debug_print();
// }
// auto out_shape = compute_shape(ck_gemm_int8{}, inputs, {pm});
// instruction::replace(ins, ck_gemm_int8{}, out_shape.with_type(migraphx::shape::half_type), inputs, {pm});
mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
// std::cout << "Next_ins inputs (post replace): " << std::endl;
// for (auto& in : std::next(ins)->inputs())
// {
// in->debug_print();
// }
// if (next_ins->name() == "softmax" or next_ins->name() == "reshape")
// {
// std::cout << "After replace: " << std::endl;
// mpm.get_module().debug_print();
// }
}
};
......
......@@ -326,22 +326,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::array<std::size_t, 3> config{m, n, k};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape.with_type(a_shape.type())}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
// if (not (get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9]))
// {
// std::cout << get_layout(a_shape) << " - " << x[0] <<std::endl;
// std::cout << get_layout(b_shape) << " - " << x[1] <<std::endl;
// std::cout << get_layout(c_shape) << " - " << x[3] <<std::endl;
// std::cout << get_type(a_shape) << " - " << x[4] <<std::endl;
// std::cout << get_type(b_shape) << " - " << x[5] <<std::endl;
// std::cout << get_type(c_shape) << " - " << x[9] <<std::endl;
// }
/* return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5];
......@@ -354,15 +338,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip.set_ds_op(v.at("post").to<std::string>());
}
ip.set_e_type(get_type(c_shape));
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
{
ip.set_c_scalar_per_vec("8");
}
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
if (a_shape.type() == shape::int8_type)
{
ip.set_c_scalar_per_vec("4");
ip.set_e_type(get_type(c_shape));
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
{
ip.set_c_scalar_per_vec("8");
}
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
{
ip.set_c_scalar_per_vec("4");
}
}
auto padding = ip.get_pad(config);
......@@ -407,7 +395,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
// std::cout << options.kernel_name << ": " << std::endl;
return compile_hip_code_object(src, options);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment