Commit 5cef60b8 authored by Alan Turner
Browse files

Cleanup

parent 2f2757ac
...@@ -112,28 +112,28 @@ void quantize_int8(program& prog, ...@@ -112,28 +112,28 @@ void quantize_int8(program& prog,
max_abs_vals->resize(param_num, 0.0f); max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale // use the calibration data to compute the quantization scale
// auto capture_prog = prog; auto capture_prog = prog;
// capture_prog.compile(t); capture_prog.compile(t);
// // use all calibration data to run the program to calculate the // use all calibration data to run the program to calculate the
// // quantization scale and shift // quantization scale and shift
// for(auto&& arg : calibration) for(auto&& arg : calibration)
// { {
// parameter_map m; parameter_map m;
// for(auto&& x : capture_prog.get_parameter_shapes()) for(auto&& x : capture_prog.get_parameter_shapes())
// { {
// if(arg.count(x.first) > 0) if(arg.count(x.first) > 0)
// { {
// assert(x.second == arg.at(x.first).get_shape()); assert(x.second == arg.at(x.first).get_shape());
// m[x.first] = t.copy_to(arg.at(x.first)); m[x.first] = t.copy_to(arg.at(x.first));
// } }
// else else
// { {
// m[x.first] = t.allocate(x.second); m[x.first] = t.allocate(x.second);
// } }
// } }
// capture_prog.eval(m); capture_prog.eval(m);
// } }
// print the quantization parameters in only the main module // print the quantization parameters in only the main module
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{})) if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
......
...@@ -96,8 +96,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -96,8 +96,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false; return false;
auto a = ins->inputs().front()->get_shape(); auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); auto b = ins->inputs().back()->get_shape();
// if(a.lens().back() > 2048) if(a.lens().back() > 2048)
// return false; return false;
return true; return true;
} }
...@@ -160,20 +160,6 @@ struct find_ck_gemm_pointwise_int8 ...@@ -160,20 +160,6 @@ struct find_ck_gemm_pointwise_int8
auto gemm_ins = r.instructions["gemm"]; auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous auto x_ins = r.instructions["x"]; // input after contiguous
auto next_ins = std::next(ins); auto next_ins = std::next(ins);
// if (next_ins->name() == "quant_dot")
// {
// std::cout << "\nins: ";
// ins->debug_print();
// std::cout << "\ngemm_ins: ";
// gemm_ins->debug_print();
// std::cout << "\nx_ins: ";
// x_ins->debug_print();
// std::cout << "\nnext: ";
// next_ins->debug_print();
// mpm.get_module().debug_print();
// }
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
...@@ -182,13 +168,6 @@ struct find_ck_gemm_pointwise_int8 ...@@ -182,13 +168,6 @@ struct find_ck_gemm_pointwise_int8
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins); auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin(); auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
// if(ins->get_shape().type() != shape::half_type)
// return;
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM before: " << std::endl;
// pm->debug_print();
// }
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
auto first_param = pm->get_parameter(names[0]); auto first_param = pm->get_parameter(names[0]);
...@@ -201,31 +180,9 @@ struct find_ck_gemm_pointwise_int8 ...@@ -201,31 +180,9 @@ struct find_ck_gemm_pointwise_int8
pm->remove_instruction(first_param); pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param); pm->remove_instruction(gemm_param);
} }
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM after: " << std::endl;
// pm->debug_print();
// }
inputs.erase(gemm_it); inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end()); inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
// std::cout << "Next_ins inputs: " << std::endl;
// for (auto& in : next_ins->inputs())
// {
// in->debug_print();
// }
// auto out_shape = compute_shape(ck_gemm_int8{}, inputs, {pm});
// instruction::replace(ins, ck_gemm_int8{}, out_shape.with_type(migraphx::shape::half_type), inputs, {pm});
mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm}); mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
// std::cout << "Next_ins inputs (post replace): " << std::endl;
// for (auto& in : std::next(ins)->inputs())
// {
// in->debug_print();
// }
// if (next_ins->name() == "softmax" or next_ins->name() == "reshape")
// {
// std::cout << "After replace: " << std::endl;
// mpm.get_module().debug_print();
// }
} }
}; };
......
...@@ -326,22 +326,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -326,22 +326,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std::array<std::size_t, 3> config{m, n, k}; std::array<std::size_t, 3> config{m, n, k};
auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape.with_type(a_shape.type())})); auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape.with_type(a_shape.type())}));
auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool { auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
// if (not (get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9]))
// {
// std::cout << get_layout(a_shape) << " - " << x[0] <<std::endl;
// std::cout << get_layout(b_shape) << " - " << x[1] <<std::endl;
// std::cout << get_layout(c_shape) << " - " << x[3] <<std::endl;
// std::cout << get_type(a_shape) << " - " << x[4] <<std::endl;
// std::cout << get_type(b_shape) << " - " << x[5] <<std::endl;
// std::cout << get_type(c_shape) << " - " << x[9] <<std::endl;
// }
/* return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5]; get_type(b_shape) == x[5];
...@@ -354,15 +338,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -354,15 +338,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip.set_ds_op(v.at("post").to<std::string>()); ip.set_ds_op(v.at("post").to<std::string>());
} }
ip.set_e_type(get_type(c_shape)); if (a_shape.type() == shape::int8_type)
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
{
ip.set_c_scalar_per_vec("8");
}
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
{ {
ip.set_c_scalar_per_vec("4"); ip.set_e_type(get_type(c_shape));
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
{
ip.set_c_scalar_per_vec("8");
}
if (std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
{
ip.set_c_scalar_per_vec("4");
}
} }
auto padding = ip.get_pad(config); auto padding = ip.get_pad(config);
...@@ -407,7 +395,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -407,7 +395,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)}, {"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}}); {"kernel", options.kernel_name}});
// std::cout << options.kernel_name << ": " << std::endl;
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment