".github/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "24c035f2e387b48d72ed40ddb5820ae49a0dc218"
Unverified commit 1eb5a1d4, authored by Charlie Lin, committed by GitHub

Dynamic ONNX Matmul (#1466)

Extends parse_matmul.hpp to handle dynamic input shapes.
Broadcasting of the outer dimensions is not yet supported for dynamic shapes.
parent 3fb5c0ef
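
For context, a minimal sketch (not part of this commit) of the kind of model the parser must now accept: an ONNX MatMul whose inputs carry unknown (None) dimensions. It mirrors the matmul_dyn_mm_test generator added below; the graph name and the direct onnx.save call are illustrative stand-ins for whatever the test harness does.

    import onnx
    from onnx import helper, TensorProto

    # MatMul with dynamic (None) dims: the row count of input '1' and the
    # column count of input '2' are unknown until runtime, so the parser
    # cannot rely on static lens().
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 7])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, None])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, None])
    node = helper.make_node('MatMul', inputs=['1', '2'], outputs=['y'])
    graph = helper.make_graph([node], 'matmul_dyn_mm', [m1, m2], [y])
    onnx.save(helper.make_model(graph), 'matmul_dyn_mm_test.onnx')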
parse_matmul.hpp
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto a0 = args[0];
        auto a1 = args[1];
        auto s0 = a0->get_shape();
        auto s1 = a1->get_shape();

        instruction_ref dot_res;
        bool is_a_prepended = false;
        bool is_b_appended  = false;
        if(s0.ndim() == 1)
        {
            is_a_prepended = true;
            a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
        }
        if(s1.ndim() == 1)
        {
            is_b_appended = true;
            a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
        }
        if(s0.dynamic() or s1.dynamic())
        {
            if(opd.op_name == "quant_dot")
            {
                MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
            }
            auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
            auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();

            // TODO: handling this case requires a new multibroadcast mode
            if(not std::equal(
                   s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
            {
                MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
            }

            dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
        }
        else
        {
            auto s0_lens = a0->get_shape().lens();
            auto s1_lens = a1->get_shape().lens();
            instruction_ref ba0 = a0;
            instruction_ref ba1 = a1;
            // try broadcasting if dimensions other than last two do not match
            if(not std::equal(
                   s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
            {
                auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
                std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
                auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
                std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
                auto output_lens =
                    compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
                l0_broadcasted_lens = output_lens;
                l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
                l1_broadcasted_lens = output_lens;
                l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
                if(s0_lens != l0_broadcasted_lens)
                {
                    ba0 = info.add_instruction(
                        make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
                }
                if(s1_lens != l1_broadcasted_lens)
                {
                    ba1 = info.add_instruction(
                        make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
                }
            }
            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
        }

        // squeeze the appended or prepended dimensions
        int64_t num_axis = dot_res->get_shape().ndim();
        if(is_a_prepended)
        {
            dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
...
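
To illustrate the semantics the parser reproduces (a numpy sketch, not from the commit): ONNX MatMul follows numpy.matmul rules, so rank-1 inputs are temporarily unsqueezed and the extra axis squeezed back out, and in the static branch any non-matching outer dimensions are multibroadcast before the dot. The dynamic branch performs the same unsqueeze/squeeze but, per the TODO above, throws rather than broadcast mismatched outer dimensions.

    import numpy as np

    # Outer-dimension broadcasting handled by the static branch:
    a = np.ones((2, 1, 4, 7))      # outer dims (2, 1)
    b = np.ones((3, 7, 5))         # outer dims (3,)
    print(np.matmul(a, b).shape)   # (2, 3, 4, 5): outer dims broadcast to (2, 3)

    # Rank-1 handling (the is_a_prepended / is_b_appended paths):
    v = np.ones(7)
    m = np.ones((7, 5))
    print(np.matmul(v, m).shape)   # (5,): the prepended axis is squeezed away
    print(np.matmul(m.T, v).shape) # (5,): the appended axis is squeezed away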
gen_onnx.py
@@ -3563,6 +3563,81 @@ def matmul_vv_test():
    return ([node], [m1, m2], [y])


@onnx_test()
def matmul_dyn_mm_test():
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 7])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, None])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, None])
    node = onnx.helper.make_node(
        'MatMul',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def matmul_dyn_mv_test():
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 7])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, 1])
    node = onnx.helper.make_node(
        'MatMul',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def matmul_dyn_vm_test():
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, None])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, None])
    node = onnx.helper.make_node(
        'MatMul',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def matmul_dyn_vv_test():
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
    node = onnx.helper.make_node(
        'MatMul',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def matmul_dyn_broadcast_error():
    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7])
    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 7, None])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, None])
    node = onnx.helper.make_node(
        'MatMul',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_test():
    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [3, 6, 16])
...
@@ -3578,6 +3653,21 @@ def matmulinteger_test():
    return ([node], [m1, m2], [y])


@onnx_test()
def matmulinteger_dyn_error():
    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [None, 6, 16])
    m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [None, 16, 8])
    y = helper.make_tensor_value_info('y', TensorProto.INT32, [None, 6, 8])
    node = onnx.helper.make_node(
        'MatMulInteger',
        inputs=['1', '2'],
        outputs=['y'],
    )
    return ([node], [m1, m2], [y])


@onnx_test()
def max_test():
    a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...
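
Note that the matmul_dyn_broadcast_error generator produces a model that is valid under numpy/ONNX broadcasting rules but hits the parser's TODO: once a dynamic shape is involved, mismatched outer dimensions raise a PARSE_MATMUL error instead of being broadcast. A quick numpy check (illustration only) of what full support would have to compute:

    import numpy as np

    # numpy broadcasts the vector over the leading batch dim of 5...
    print(np.matmul(np.ones(7), np.ones((5, 7, 3))).shape)  # (5, 3)
    # ...but MIGraphX's dynamic-shape path currently throws:
    # "PARSE_MATMUL: dynamic shape broadcasting not supported"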
onnx_test.cpp
@@ -3432,6 +3432,92 @@ TEST_CASE(matmul_vv_test)
    EXPECT(p == prog);
}

TEST_CASE(matmul_dyn_mm_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter(
        "1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
    auto l1 = mm->add_parameter(
        "2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {1, 5, 3}}});
    auto ret = migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
    options.map_dyn_input_dims["2"] = {{7, 7, 0}, {1, 5, 3}};
    auto prog = parse_onnx("matmul_dyn_mm_test.onnx", options);
    EXPECT(p == prog);
}

TEST_CASE(matmul_dyn_mv_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter(
        "1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
    auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
    auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
    auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
    auto prog = parse_onnx("matmul_dyn_mv_test.onnx", options);
    EXPECT(p == prog);
}

TEST_CASE(matmul_dyn_vm_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
    auto l1 = mm->add_parameter(
        "2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {4, 10, 8}}});
    auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
    auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.map_dyn_input_dims["2"] = {{7, 7, 0}, {4, 10, 8}};
    auto prog = parse_onnx("matmul_dyn_vm_test.onnx", options);
    EXPECT(p == prog);
}

TEST_CASE(matmul_dyn_vv_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape::dynamic_dimension dd{5, 8, 7};
    auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {dd}});
    auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {dd}});
    auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
    auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
    auto res =
        migraphx::add_apply_alpha_beta(*mm, {sl0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
    auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
    mm->add_return({ret});

    migraphx::onnx_options options;
    options.default_dyn_dim_value = dd;
    auto prog = parse_onnx("matmul_dyn_vv_test.onnx", options);
    EXPECT(p == prog);
}

TEST_CASE(matmul_dyn_broadcast_error)
{
    migraphx::onnx_options options;
    options.default_dyn_dim_value = {1, 4, 0};
    EXPECT(test::throws([&] { migraphx::parse_onnx("matmul_dyn_broadcast_error.onnx", options); }));
}

TEST_CASE(matmulinteger_test)
{
    migraphx::program p;
...
@@ -3445,6 +3531,13 @@ TEST_CASE(matmulinteger_test)
    EXPECT(p == prog);
}

TEST_CASE(matmulinteger_dyn_error)
{
    migraphx::onnx_options options;
    options.default_dyn_dim_value = {1, 4, 0};
    EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_dyn_error.onnx", options); }));
}

TEST_CASE(max_test)
{
    migraphx::program p;
...