Commit 492865a3 authored by yanyan's avatar yanyan
Browse files

faster subm indice generation by

unroll loop for common params
parent 9a23d934
...@@ -651,7 +651,8 @@ def main(algo=spconv.ConvAlgo.Native, dtype=torch.float32): ...@@ -651,7 +651,8 @@ def main(algo=spconv.ConvAlgo.Native, dtype=torch.float32):
indices_t = torch.from_numpy(indices).int().to(device).to(dtype) indices_t = torch.from_numpy(indices).int().to(device).to(dtype)
features_t = torch.from_numpy(features).to(device).to(dtype) features_t = torch.from_numpy(features).to(device).to(dtype)
features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype) features_dense_t = torch.from_numpy(features_dense).to(device).to(
dtype)
net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d, net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d,
algo=algo).to(device).to(dtype) algo=algo).to(device).to(dtype)
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
...@@ -718,7 +719,8 @@ def main_subm(algo, dtype=torch.float32): ...@@ -718,7 +719,8 @@ def main_subm(algo, dtype=torch.float32):
indices_t = torch.from_numpy(indices).int().to(device).to(dtype) indices_t = torch.from_numpy(indices).int().to(device).to(dtype)
features_t = torch.from_numpy(features).to(device).to(dtype) features_t = torch.from_numpy(features).to(device).to(dtype)
features_dense_t = torch.from_numpy(features_dense).to(device).to(dtype) features_dense_t = torch.from_numpy(features_dense).to(device).to(
dtype)
net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d, net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d,
algo=algo).to(device).to(dtype) algo=algo).to(device).to(dtype)
net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p, net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
...@@ -750,8 +752,8 @@ def main_subm(algo, dtype=torch.float32): ...@@ -750,8 +752,8 @@ def main_subm(algo, dtype=torch.float32):
if __name__ == '__main__': if __name__ == '__main__':
main(algo=spconv.ConvAlgo.Native, dtype=torch.float32) # main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.float32)
main(algo=spconv.ConvAlgo.Native, dtype=torch.half) # main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.half)
# TestCase().assertAllClose(out_my, out_ref) # TestCase().assertAllClose(out_my, out_ref)
# unittest.main() # unittest.main()
# TestSpConv().testSpConv3d() TestSpConv().testSpConv3d()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment