Commit d9d47733 authored by jerryyin's avatar jerryyin
Browse files

tmp

parent 27297b0b
......@@ -80,17 +80,17 @@ static void create_pointwise_modules(module_pass_manager& mpm)
if(contains(param_map, input))
continue;
auto scalar = get_scalar(input);
if(scalar.empty())
{
//if(scalar.empty())
//{
pointwise_inputs.push_back(input);
param_map[input] =
pm->add_parameter("x" + std::to_string(i), shape{input->get_shape().type()});
i++;
}
else
{
param_map[input] = pm->add_literal(scalar);
}
//}
//else
//{
// param_map[input] = pm->add_literal(scalar);
//}
}
// Don't create pointwise module if no inputs are detected
......
......@@ -202,7 +202,15 @@ struct check_shapes
*/
const check_shapes& same_dims() const
{
if(not this->same([](const shape& s) { return s.max_lens(); }))
if(not this->same([](const shape& s) {
auto print = [](std::vector<std::size_t> lens) {
for (auto i : lens)
std::cout << i << " ";
};
std::cout << "s.lens() = ";
print(s.max_lens());
std::cout << std::endl;
return s.max_lens(); }))
MIGRAPHX_THROW(prefix() + "Dimensions do not match");
if(this->any_of([&](const shape& s) { return s.dynamic(); }))
if(not this->same([](const shape& s) { return s.min_lens(); }))
......
......@@ -48,6 +48,9 @@ struct dequantizelinear
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
std::cout << "input[0] shape: " << inputs[0].type_string() << std::endl;
std::cout << "input[1] shape: " << inputs[1].type_string() << std::endl;
check_shapes{inputs, *this}.same_dims().has(2, 3);
if(inputs.size() == 3 and inputs[0].type() != inputs[2].type())
{
......
......@@ -191,6 +191,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
std::cout << error << " " << threshold << std::endl;
if(out_error != nullptr)
*out_error = error;
return error <= threshold;
......
......@@ -153,16 +153,40 @@ struct find_mlir_op
auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
gemm_based_op->inputs().at(1)->get_shape());
auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
std::cout << "Converting from module: " << pm->name() << std::endl;
pm->debug_print();
std::cout << "Converting to module: " << mm->name() << std::endl;
// Convert the parameters to the new module
std::transform(names.begin(),
names.end(),
ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
[&](auto name, auto input) {
//ins->debug_print();
if(input == x_ins)
return std::make_pair(pm->get_parameter(name), conv);
return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape()));
});
// Convert the literals to the new module
//for(auto&& ins : iterator_for(*pm))
//{
// if(ins->name() != "@literal")
// {
// continue;
// }
// auto shape = conv->get_shape().with_type(ins->get_shape().type());
// ;
// literal l{shape, ins->get_literal().data()};
// param_map[ins] = mm->add_literal(l);
// //ins->debug_print();
//}
// Insert the converted instructions and add the return
mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));
std::vector<instruction_ref> inputs;
......@@ -173,6 +197,7 @@ struct find_mlir_op
inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
mpm.get_module().replace_instruction(
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
mm->debug_print();
}
};
......
......@@ -86,6 +86,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
//std::cout << src << std::endl;
return compile_hip_code_object(src, options);
}
......
......@@ -97,6 +97,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::half_type);
unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::tuple_type);
// clang-format off
......
......@@ -34,14 +34,35 @@ struct test_dequantizelinear : verify_program<test_dequantizelinear>
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::int8_type, {2, 2, 2}};
migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}};
migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}};
auto input1 = mm->add_parameter("x", sx);
auto input2 = mm->add_parameter("x_scale", ss);
auto x = mm->add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}});
auto w = mm->add_parameter("w", {migraphx::shape::int8_type, {2, 8, 3, 3}});
auto b = mm->add_parameter("b", {migraphx::shape::int32_type, {1, 2, 2, 2}});
auto conv = mm->add_instruction(migraphx::make_op("quant_convolution"), x, w);
migraphx::shape ss{migraphx::shape::float_type, {1, 2, 2, 2}};
migraphx::shape sz{migraphx::shape::int32_type, {1, 2, 2, 2}};
//auto input2 = mm->add_parameter("x_scale", ss);
std::vector<float> datax = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
auto input2 = mm->add_literal(migraphx::literal(ss, datax));
auto input3 = mm->add_parameter("x_zero_point", sz);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), input1, input2, input3);
mm->add_return({r});
auto dequant =
mm->add_instruction(migraphx::make_op("dequantizelinear"), conv, input2, input3);
// conv, input2);
mm->add_return({dequant});
// mm->add_return({conv});
// auto add = mm->add_instruction(migraphx::make_op("add"), conv, b);
// mm->add_return({add});
// auto r = mm->add_instruction(migraphx::make_op("quantizelinear"), dequant, input2,
// input3); mm->add_return({r});
// auto s = migraphx::gpu::dump_mlir(m);
// migraphx::shape sx{migraphx::shape::int8_type, {2, 2, 2}};
// migraphx::shape ss{migraphx::shape::float_type, {2, 2, 2}};
// migraphx::shape sz{migraphx::shape::int8_type, {2, 2, 2}};
// auto input1 = mm->add_parameter("x", sx);
// auto input2 = mm->add_parameter("x_scale", ss);
// auto input3 = mm->add_parameter("x_zero_point", sz);
// auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), input1, input2,
// input3); mm->add_return({r});
return p;
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment