Commit cc30b7c1 authored by Alan Turner

Add all layouts and make qdq use fp16 instead of float

parent 734c2e74
@@ -117,23 +117,23 @@ void quantize_int8(program& prog,
     // use all calibration data to run the program to calculate the
     // quantization scale and shift
-    for(auto&& arg : calibration)
-    {
-        parameter_map m;
-        for(auto&& x : capture_prog.get_parameter_shapes())
-        {
-            if(arg.count(x.first) > 0)
-            {
-                assert(x.second == arg.at(x.first).get_shape());
-                m[x.first] = t.copy_to(arg.at(x.first));
-            }
-            else
-            {
-                m[x.first] = t.allocate(x.second);
-            }
-        }
-        capture_prog.eval(m);
-    }
+    // for(auto&& arg : calibration)
+    // {
+    //     parameter_map m;
+    //     for(auto&& x : capture_prog.get_parameter_shapes())
+    //     {
+    //         if(arg.count(x.first) > 0)
+    //         {
+    //             assert(x.second == arg.at(x.first).get_shape());
+    //             m[x.first] = t.copy_to(arg.at(x.first));
+    //         }
+    //         else
+    //         {
+    //             m[x.first] = t.allocate(x.second);
+    //         }
+    //     }
+    //     capture_prog.eval(m);
+    // }
     // print the quantization parameters in only the main module
     if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
...
@@ -40,7 +40,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     if(x->get_shape().type() != y_scale->get_shape().type())
     {
-        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
+        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), x);
     }
     auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
     auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
@@ -48,7 +48,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     if(ins->inputs().size() == 3)
     {
         auto zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+            ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[2]);
         add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
     }
@@ -73,13 +73,13 @@ void apply_dequantizelinear(module& m, instruction_ref ins)
 {
     assert(ins->name() == "dequantizelinear");
     auto x = m.insert_instruction(
-        ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
+        ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[0]);
     auto x_scale = ins->inputs()[1];
     if(ins->inputs().size() == 3)
     {
         auto x_zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+            ins, make_op("convert", {{"target_type", shape::half_type}}), ins->inputs()[2]);
         x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
     }
...
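The two hunks above only change the intermediate compute type of the quantize/dequantize decomposition from float to half; the arithmetic itself is unchanged. Below is a minimal standalone sketch of that arithmetic, using float as a stand-in for fp16. `quantize`/`dequantize` are illustrative names, not MIGraphX APIs, and the clamp is added only to make the example self-contained.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// q = round(x / scale) + zero_point, clamped to the int8 range
// (mirrors the div/round/add built by apply_quantizelinear).
std::int8_t quantize(float x, float scale, std::int8_t zero_point)
{
    float q = std::round(x / scale) + static_cast<float>(zero_point);
    return static_cast<std::int8_t>(std::clamp(q, -128.0f, 127.0f));
}

// x_hat = (q - zero_point) * scale, the standard dequantize step
// (the convert/sub shown above, followed by the multiply by x_scale).
float dequantize(std::int8_t q, float scale, std::int8_t zero_point)
{
    return (static_cast<float>(q) - static_cast<float>(zero_point)) * scale;
}
```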
@@ -60,8 +60,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
-    if(a.lens().back() > 2048)
-        return false;
+    // if(a.lens().back() > 2048)
+    //     return false;
     return true;
 }
@@ -87,7 +87,7 @@ struct find_ck_gemm_pointwise
         auto gemm_it  = std::find(inputs.begin(), inputs.end(), x_ins);
         auto gemm_idx = gemm_it - inputs.begin();
         assert(gemm_it != inputs.end());
-        if(ins->get_shape().type() != shape::int8_type and ins->get_shape().type())
+        if(ins->get_shape().type() != shape::int8_type)
            return;
        if(gemm_idx != 0)
        {
...
@@ -71,47 +71,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Empty_Tuple = ck::Tuple<>;
-using GEMM = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
-    Row,
-    Row,
-    Empty_Tuple,
-    Row,
-    int8_t,
-    int8_t,
-    int32_t,
-    Empty_Tuple,
-    int8_t, //EDataType
-    PassThrough,
-    PassThrough,
-    PassThrough,
-    ck::tensor_operation::device::GemmSpecialization::MNKPadding,
-    256,
-    128,
-    128,
-    16,
-    4,
-    4,
-    4,
-    1,
-    S<8,2>,
-    S<8,2>,
-    S<8,1,1,4>,
-    S<2,1,128,1>,
-    S<1,2,0,3>,
-    S<1,2,0,3>,
-    S<4,1,1,4>,
-    S<1,2,0,3>,
-    S<1,1,1,4>,
-    S<2,1,4,4>,
-    S<8,1,32,1>,
-    S<0,3,1,2>,
-    S<0,3,1,2>,
-    S<1,1,4,1>,
-    S<0,3,1,2>,
-    S<1,1,4,4>,
-    S<0,1,2,3,4,5>,
-    5,
-    4>;
+using GEMM = ck::tensor_operation::device::${instance1}${padding}${instance2};
 namespace migraphx {
@@ -335,6 +295,34 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto a_shape = inputs[0];
         auto b_shape = inputs[1];
         auto c_shape = inputs.back();
+        auto transa  = transposed_matrix(a_shape);
+        auto transb  = transposed_matrix(b_shape);
+        std::string instance_str1;
+        std::string instance_str2;
+        if (transa and not transb)
+        {
+            instance_str1 = "DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
+            instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>";
+        }
+        else if (transa and transb)
+        {
+            instance_str1 = "DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
+            instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>";
+        }
+        else if (not transa and not transb)
+        {
+            instance_str1 = "DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
+            instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>";
+        }
+        else
+        {
+            instance_str1 = "DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
+            instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>";
+        }
         auto rank      = a_shape.lens().size();
         auto b_strides = b_shape.strides();
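The branch chain added above picks a CK device-GEMM instance string per layout: a transposed A selects a column-major ALayout, a transposed B a column-major BLayout, and the output layout stays Row in every branch; each branch's suffix also carries the block-transfer descriptors matching those layouts. A condensed sketch of just the layout mapping is shown below; `pick_layouts` is an illustrative helper, not part of this change.

```cpp
#include <string>
#include <utility>

// Condensed view of the if/else chain above: the first two layout template
// arguments follow directly from whether A and B are transposed.
std::pair<std::string, std::string> pick_layouts(bool transa, bool transb)
{
    std::string a_layout = transa ? "Col" : "Row";
    std::string b_layout = transb ? "Col" : "Row";
    return {a_layout, b_layout}; // the D/E layout stays "Row" in every case
}
```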
@@ -361,19 +349,26 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
             ip.set_ds_op(v.at("post").to<std::string>());
         }
+        auto m_per_block = 128;
+        auto n_per_block = 128;
+        auto k_per_block = 16;
         auto padding = ip.get_pad(config);
         std::string gemm_type;
-        for(auto i : range(padding.size()))
-        {
-            if(padding[i] != 0)
-                gemm_type += keys[i];
-        }
-        if(gemm_type.empty())
-            gemm_type = "Default";
+        // if (int_div_ceil(m, m_per_block) * m_per_block - m != 0)
+        //     gemm_type += "M";
+        // if (int_div_ceil(n, n_per_block) * n_per_block - n != 0)
+        //     gemm_type += "N";
+        // if (int_div_ceil(k, k_per_block) * k_per_block - k != 0)
+        //     gemm_type += "K";
+        if ((int_div_ceil(m, m_per_block) * m_per_block - m != 0) or (int_div_ceil(n, n_per_block) * n_per_block - n != 0))
+            gemm_type = "MNPadding";
         else
-            gemm_type += "Padding";
+            gemm_type = "Default";
         ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
+        std::string padding_str = "ck::tensor_operation::device::GemmSpecialization::" + gemm_type;
+        std::cout << padding_str << std::endl;
+        //std::exit(0);
         auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128);
         ; // ip.get_grid_size(config);
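A worked example of the new padding check, with illustrative sizes and assuming `int_div_ceil` is plain ceiling integer division: m = 384 is already a multiple of the 128-wide M tile, while n = 1000 is not, so the kernel is compiled with MNPadding. Note that the K check stays commented out, so only M/N padding is ever selected here.

```cpp
#include <cstdint>
#include <iostream>

// Assumed definition of int_div_ceil: ceiling integer division.
std::int64_t int_div_ceil(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

int main()
{
    std::int64_t m = 384, n = 1000, m_per_block = 128, n_per_block = 128;
    bool pad_m = int_div_ceil(m, m_per_block) * m_per_block - m != 0; // 384  - 384  = 0  -> false
    bool pad_n = int_div_ceil(n, n_per_block) * n_per_block - n != 0; // 1024 - 1000 = 24 -> true
    std::cout << ((pad_m or pad_n) ? "MNPadding" : "Default") << "\n"; // prints MNPadding
}
```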
@@ -402,7 +397,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         options.params += " -DMIGRAPHX_CK_CHECK=1";
         auto src = interpolate_string(ck_gemm_kernel,
-                                      {{"instance", ip.str()},
+                                      {{"instance1", instance_str1},
+                                       {"instance2", instance_str2},
+                                       {"padding", padding_str},
                                        {"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"blocks_per_batch", to_string(blocks_per_batch)},
...
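With this hunk, the kernel template's `${instance1}${padding}${instance2}` expands into the same kind of `using GEMM = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<...>;` alias that the deleted hard-coded instantiation provided, except the layouts and GemmSpecialization are now chosen per problem. A naive stand-in for that substitution step is sketched below; MIGraphX's `interpolate_string` is the real mechanism, and `substitute` here only illustrates the idea.

```cpp
#include <cstddef>
#include <map>
#include <string>

// Naive placeholder substitution: replace every "${key}" in the kernel
// template text with the corresponding value.
std::string substitute(std::string text, const std::map<std::string, std::string>& vars)
{
    for(const auto& [key, value] : vars)
    {
        const std::string token = "${" + key + "}";
        for(std::size_t pos = text.find(token); pos != std::string::npos;
            pos = text.find(token, pos))
        {
            text.replace(pos, token.size(), value);
            pos += value.size();
        }
    }
    return text;
}
```

Feeding it instance_str1, padding_str, and instance_str2 yields the concrete GEMM alias for the selected layouts and padding mode.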