Unverified Commit 0058fbc2 authored by Louis-J's avatar Louis-J Committed by GitHub
Browse files

test(speedup): add tests for aten::split and aten::leaky_relu (#5070)

parent 792a0a91
...@@ -47,8 +47,8 @@ class TorchModel1(torch.nn.Module): ...@@ -47,8 +47,8 @@ class TorchModel1(torch.nn.Module):
""" """
test for: test for:
add, sub, mul, div, exp, matmul, add, sub, mul, div, exp, matmul,
relu, gelu, tanh, silu, sigmod, softmax, relu, gelu, tanh, silu, sigmod, softmax, leaky_relu,
size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd, size, unsqueeze, flatten, cat, slice, reshape, transpose, t, select, permute, constant_pad_nd, split
mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d, mean, avg_pool2d, max_pool2d, sum, adaptive_avg_pool2d,
to, Int, view, to, Int, view,
type_as, expand_as, contiguous, type_as, expand_as, contiguous,
...@@ -89,8 +89,9 @@ class TorchModel1(torch.nn.Module): ...@@ -89,8 +89,9 @@ class TorchModel1(torch.nn.Module):
x = self.conv1(x) x = self.conv1(x)
y1 = self.pool1(F.relu(x)) y1 = self.pool1(F.relu(x))
y2 = self.pool1(F.gelu(x)) y2 = self.pool1(F.gelu(x))
y3 = self.pool1(F.leaky_relu(x))
x = y1 + y2 x = y1 + y2 + y3
x = x + 0.00001 x = x + 0.00001
...@@ -133,8 +134,9 @@ class TorchModel1(torch.nn.Module): ...@@ -133,8 +134,9 @@ class TorchModel1(torch.nn.Module):
x = x.type_as(x) x = x.type_as(x)
x = x.expand_as(x) x = x.expand_as(x)
x = torch.matmul(x, x.t()) x = torch.matmul(x, x.t())
x = torch.cat([x, x], dim=1) x = torch.split(x, 1, dim=1)
# x = self.cond(x) x = torch.cat(x, dim=1)
# x = self.cond(x) # condition is not support now
x = self.asub(x) x = self.asub(x)
x = torch.constant_pad_nd(x, (1,1,1,1), 3.14159) x = torch.constant_pad_nd(x, (1,1,1,1), 3.14159)
...@@ -159,7 +161,7 @@ class AutoConvTestCase(unittest.TestCase): ...@@ -159,7 +161,7 @@ class AutoConvTestCase(unittest.TestCase):
assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75 assert 0.45 < real_sparsity_list[0]['total_sparsity'] < 0.75
print('the shape of output of the infer:', model(dummy_input).shape) print('the shape of output of the infer:', model(dummy_input).shape)
assert model(dummy_input).shape == torch.Size((5, 8)) assert model(dummy_input).shape == torch.Size((5, 5))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment