Commit a0792e88 authored by Ruilong Li
Browse files

profile cub

parent 75a7b021
......@@ -175,9 +175,40 @@ def test_exclusive_prod():
assert torch.allclose(outputs1, outputs3)
assert torch.allclose(grad1, grad3)
def profile():
    """Drive the two inclusive-sum implementations in timed loops.

    Builds a random 8192 x 8192 tensor on the module-level ``device``,
    flattens it, and treats every row as one chunk. Two loops of 2000
    iterations each then run a forward pass followed by ``sum().backward()``:

    1. ``nerfacc.scan.inclusive_sum`` with a ``packed_info`` tensor of
       (chunk start, chunk length) pairs.
    2. ``nerfacc.scan_cub.inclusive_sum_cub`` with a per-element chunk
       index tensor.

    Each loop is wrapped in ``tqdm.trange``; presumably the progress-bar
    rate is what gets compared between the two implementations. Nothing
    is returned.
    """
    import tqdm

    from nerfacc.scan import inclusive_sum
    from nerfacc.scan_cub import inclusive_sum_cub

    num_iters = 2000

    torch.manual_seed(42)
    data = torch.rand((8192, 8192), device=device, requires_grad=True)
    n_rows, n_cols = data.shape

    # Keyword arguments shared by every integer helper tensor built below.
    long_on_device = dict(device=device, dtype=torch.long)

    # One chunk per row: start offsets into the flattened data, and a
    # constant length equal to the row width.
    chunk_starts = torch.arange(0, data.numel(), n_cols, **long_on_device)
    chunk_cnts = torch.full((n_rows,), n_cols, **long_on_device)
    packed_info = torch.stack([chunk_starts, chunk_cnts], dim=-1)
    flatten_data = data.flatten()

    # Drain pending GPU work so setup does not leak into the first loop.
    torch.cuda.synchronize()
    for _ in tqdm.trange(num_iters):
        packed_out = inclusive_sum(flatten_data, packed_info=packed_info)
        # ``data`` is a leaf tensor, so its gradient accumulates over iterations.
        packed_out.sum().backward()

    # Per-element chunk id: row index repeated once for every column.
    indices = (
        torch.arange(n_rows, **long_on_device)
        .repeat_interleave(n_cols)
        .flatten()
    )

    # Drain the first loop's queued work before starting the second loop.
    torch.cuda.synchronize()
    for _ in tqdm.trange(num_iters):
        cub_out = inclusive_sum_cub(flatten_data, indices)
        cub_out.sum().backward()
if __name__ == "__main__":
    # Script entry point: run each scan test in sequence.
    for run_test in (
        test_inclusive_sum,
        test_exclusive_sum,
        test_inclusive_prod,
        test_exclusive_prod,
    ):
        run_test()
    # Uncomment to run the benchmark loops as well.
    # profile()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment