bm_nearest_neighbor_points.py 896 Bytes
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product
import torch
from fvcore.common.benchmark import benchmark

from test_nearest_neighbor_points import TestNearestNeighborPoints


def bm_nn_points() -> None:
    """Run the nearest-neighbor-points benchmarks over a grid of problem sizes.

    Every combination of the ``N``, ``D``, ``P1`` and ``P2`` values below is
    benchmarked with the Python implementation under the label "NN_PYTHON".
    If ``torch.cuda.is_available()`` is true, the same grid is also
    benchmarked with the CUDA implementation under the label "NN_CUDA".
    Each benchmark uses a single warmup iteration.
    """
    # NOTE(review): presumably N is the batch size, D the point dimension and
    # P1/P2 the number of points in the two clouds -- confirm against
    # TestNearestNeighborPoints.
    N = [1, 4, 32]
    D = [3, 4]
    P1 = [1, 128]
    P2 = [32, 128]

    # Build a concrete list rather than a one-shot generator: the same
    # kwargs_list is reused for the CUDA run below.
    kwargs_list = [
        {"N": n, "D": d, "P1": p1, "P2": p2}
        for n, d, p1, p2 in product(N, D, P1, P2)
    ]

    benchmark(
        TestNearestNeighborPoints.bm_nn_points_python_with_init,
        "NN_PYTHON",
        kwargs_list,
        warmup_iters=1,
    )

    if torch.cuda.is_available():
        benchmark(
            TestNearestNeighborPoints.bm_nn_points_cuda_with_init,
            "NN_CUDA",
            kwargs_list,
            warmup_iters=1,
        )