Unverified Commit 63f313bf authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Fix mul match in operation def (#4262)

parent 0efabe96
...@@ -254,6 +254,13 @@ class AtenFloordiv(PyTorchOperation): ...@@ -254,6 +254,13 @@ class AtenFloordiv(PyTorchOperation):
return f'{output} = {inputs[0]} // {inputs[1]}' return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation): class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len'] _ori_type_name = ['aten::len']
......
...@@ -73,6 +73,7 @@ class TestModels(unittest.TestCase, ConvertMixin): ...@@ -73,6 +73,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
def test_append_input_tensor(self): def test_append_input_tensor(self):
from typing import List from typing import List
class Net(nn.Module): class Net(nn.Module):
def __init__(self, num_nodes): def __init__(self, num_nodes):
super().__init__() super().__init__()
...@@ -80,6 +81,7 @@ class TestModels(unittest.TestCase, ConvertMixin): ...@@ -80,6 +81,7 @@ class TestModels(unittest.TestCase, ConvertMixin):
self.num_nodes = num_nodes self.num_nodes = num_nodes
for _ in range(num_nodes): for _ in range(num_nodes):
self.ops.append(nn.Linear(16, 16)) self.ops.append(nn.Linear(16, 16))
def forward(self, x: List[torch.Tensor]): def forward(self, x: List[torch.Tensor]):
state = x state = x
for ops in self.ops: for ops in self.ops:
...@@ -90,6 +92,19 @@ class TestModels(unittest.TestCase, ConvertMixin): ...@@ -90,6 +92,19 @@ class TestModels(unittest.TestCase, ConvertMixin):
x = torch.rand((1, 16), dtype=torch.float) x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], )) self.run_test(model, ([x], ))
def test_channels_shuffle(self):
class Net(nn.Module):
def forward(self, x):
bs, num_channels, height, width = x.size()
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]
model = Net()
x = torch.rand((1, 64, 224, 224), dtype=torch.float)
self.run_test(model, (x, ))
def test_identity_node(self): def test_identity_node(self):
class Net(nn.Module): class Net(nn.Module):
def forward(self, x): def forward(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment