bm_knn.py 698 Bytes
Newer Older
Justin Johnson's avatar
Justin Johnson committed
1
2
3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from itertools import product
4

Justin Johnson's avatar
Justin Johnson committed
5
from fvcore.common.benchmark import benchmark
Georgia Gkioxari's avatar
Georgia Gkioxari committed
6
from test_knn import TestKNN
Justin Johnson's avatar
Justin Johnson committed
7
8
9
10


def _knn_kwargs_list(backends) -> list:
    """Build the keyword-argument sets passed to each KNN benchmark.

    Args:
        backends: Device strings (e.g. "cpu", "cuda:0") to sweep over.

    Returns:
        A list with one dict per element of the Cartesian product of the
        N, P1, P2, D, K and device sweeps. The list is in
        ``itertools.product`` order, so the device varies fastest.
    """
    Ns = [32]
    P1s = [256]
    P2s = [128, 512]
    Ds = [3]
    Ks = [24]
    return [
        {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "device": b}
        for N, P1, P2, D, K, b in product(Ns, P1s, P2s, Ds, Ks, backends)
    ]


def bm_knn() -> None:
    """Run the KNN_SQUARE and KNN_RAGGED benchmarks over the parameter sweep.

    The benchmarks always run on "cpu". They also run on "cuda:0" when
    CUDA is available, so the script no longer fails on CPU-only machines.
    """
    # Local import: the file header visible here does not import torch.
    import torch

    backends = ["cpu"]
    if torch.cuda.is_available():
        backends.append("cuda:0")

    kwargs_list = _knn_kwargs_list(backends)

    benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)

    benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)