Commit d57daa6f authored by Patrick Labatut, committed by Facebook GitHub Bot

Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
parent eb512ffd
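
The change below is purely mechanical reformatting. As an illustration only (this is not the fbsource lint pipeline, and the repository's actual configuration may differ), black and isort can be applied programmatically to a snippet like the one in the first hunk below; requires `pip install black isort`.

import black
import isort

src = (
    "from .utils import (\n"
    "    list_to_packed,\n"
    "    list_to_padded,\n"
    "    packed_to_list,\n"
    "    padded_to_list\n"
    ")\n"
)

# Normalize import ordering/wrapping first, then let black re-wrap lines
# to its default 88-character limit.
sorted_src = isort.code(src, profile="black")
formatted = black.format_str(sorted_src, mode=black.Mode())
print(formatted)
# Expected: the import collapses onto a single line, matching the first hunk below:
# from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
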
@@ -3,11 +3,7 @@
from .meshes import Meshes, join_meshes
from .pointclouds import Pointclouds
from .textures import Textures
-from .utils import (
-    list_to_packed,
-    list_to_padded,
-    packed_to_list,
-    padded_to_list,
-)
+from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List
import torch
from . import utils as struct_utils
@@ -314,14 +315,11 @@ class Meshes(object):
if isinstance(verts, list) and isinstance(faces, list):
self._verts_list = verts
self._faces_list = [
-    f[f.gt(-1).all(1)].to(torch.int64) if len(f) > 0 else f
-    for f in faces
+    f[f.gt(-1).all(1)].to(torch.int64) if len(f) > 0 else f for f in faces
]
self._N = len(self._verts_list)
self.device = torch.device("cpu")
-self.valid = torch.zeros(
-    (self._N,), dtype=torch.bool, device=self.device
-)
+self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:
self.device = self._verts_list[0].device
self._num_verts_per_mesh = torch.tensor(
@@ -348,18 +346,14 @@ class Meshes(object):
elif torch.is_tensor(verts) and torch.is_tensor(faces):
if verts.size(2) != 3 and faces.size(2) != 3:
-raise ValueError(
-    "Verts and Faces tensors have incorrect dimensions."
-)
+raise ValueError("Verts and Faces tensors have incorrect dimensions.")
self._verts_padded = verts
self._faces_padded = faces.to(torch.int64)
self._N = self._verts_padded.shape[0]
self._V = self._verts_padded.shape[1]
self.device = self._verts_padded.device
-self.valid = torch.zeros(
-    (self._N,), dtype=torch.bool, device=self.device
-)
+self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:
# Check that padded faces - which have value -1 - are at the
# end of the tensors
@@ -400,12 +394,8 @@ class Meshes(object):
# Set the num verts/faces on the textures if present.
if self.textures is not None:
-self.textures._num_faces_per_mesh = (
-    self._num_faces_per_mesh.tolist()
-)
-self.textures._num_verts_per_mesh = (
-    self._num_verts_per_mesh.tolist()
-)
+self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
+self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
def __len__(self):
return self._N
@@ -665,8 +655,7 @@ class Meshes(object):
self._verts_padded_to_packed_idx = torch.cat(
[
-torch.arange(v, dtype=torch.int64, device=self.device)
-+ i * self._V
+torch.arange(v, dtype=torch.int64, device=self.device) + i * self._V
for (i, v) in enumerate(self._num_verts_per_mesh)
],
dim=0,
@@ -706,15 +695,10 @@ class Meshes(object):
tensor of normals of shape (N, max(V_n), 3).
"""
if self.isempty():
-return torch.zeros(
-    (self._N, 0, 3), dtype=torch.float32, device=self.device
-)
+return torch.zeros((self._N, 0, 3), dtype=torch.float32, device=self.device)
verts_normals_list = self.verts_normals_list()
return struct_utils.list_to_padded(
-    verts_normals_list,
-    (self._V, 3),
-    pad_value=0.0,
-    equisized=self.equisized,
+    verts_normals_list, (self._V, 3), pad_value=0.0, equisized=self.equisized
)
def faces_normals_packed(self):
@@ -750,15 +734,10 @@ class Meshes(object):
tensor of normals of shape (N, max(F_n), 3).
"""
if self.isempty():
-return torch.zeros(
-    (self._N, 0, 3), dtype=torch.float32, device=self.device
-)
+return torch.zeros((self._N, 0, 3), dtype=torch.float32, device=self.device)
faces_normals_list = self.faces_normals_list()
return struct_utils.list_to_padded(
-    faces_normals_list,
-    (self._F, 3),
-    pad_value=0.0,
-    equisized=self.equisized,
+    faces_normals_list, (self._F, 3), pad_value=0.0, equisized=self.equisized
)
def faces_areas_packed(self):
@@ -797,9 +776,7 @@ class Meshes(object):
return
faces_packed = self.faces_packed()
verts_packed = self.verts_packed()
-face_areas, face_normals = mesh_face_areas_normals(
-    verts_packed, faces_packed
-)
+face_areas, face_normals = mesh_face_areas_normals(verts_packed, faces_packed)
self._faces_areas_packed = face_areas
self._faces_normals_packed = face_normals
@@ -813,9 +790,7 @@ class Meshes(object):
refresh: Set to True to force recomputation of vertex normals.
Default: False.
"""
-if not (
-    refresh or any(v is None for v in [self._verts_normals_packed])
-):
+if not (refresh or any(v is None for v in [self._verts_normals_packed])):
return
if self.isempty():
@@ -867,8 +842,7 @@ class Meshes(object):
Computes the padded version of meshes from verts_list and faces_list.
"""
if not (
-    refresh
-    or any(v is None for v in [self._verts_padded, self._faces_padded])
+    refresh or any(v is None for v in [self._verts_padded, self._faces_padded])
):
return
@@ -887,16 +861,10 @@ class Meshes(object):
)
else:
self._faces_padded = struct_utils.list_to_padded(
-    faces_list,
-    (self._F, 3),
-    pad_value=-1.0,
-    equisized=self.equisized,
+    faces_list, (self._F, 3), pad_value=-1.0, equisized=self.equisized
)
self._verts_padded = struct_utils.list_to_padded(
-    verts_list,
-    (self._V, 3),
-    pad_value=0.0,
-    equisized=self.equisized,
+    verts_list, (self._V, 3), pad_value=0.0, equisized=self.equisized
)
# TODO(nikhilar) Improve performance of _compute_packed.
@@ -1055,9 +1023,7 @@ class Meshes(object):
face_to_edge = inverse_idxs[face_to_edge]
self._faces_packed_to_edges_packed = face_to_edge
-num_edges_per_mesh = torch.zeros(
-    self._N, dtype=torch.int32, device=self.device
-)
+num_edges_per_mesh = torch.zeros(self._N, dtype=torch.int32, device=self.device)
ones = torch.ones(1, dtype=torch.int32, device=self.device).expand(
self._edges_packed_to_mesh_idx.shape
)
......
@@ -176,17 +176,13 @@ class Pointclouds(object):
self._points_list = points
self._N = len(self._points_list)
self.device = torch.device("cpu")
-self.valid = torch.zeros(
-    (self._N,), dtype=torch.bool, device=self.device
-)
+self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
self._num_points_per_cloud = []
if self._N > 0:
for p in self._points_list:
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
-raise ValueError(
-    "Clouds in list must be of shape Px3 or empty"
-)
+raise ValueError("Clouds in list must be of shape Px3 or empty")
self.device = self._points_list[0].device
num_points_per_cloud = torch.tensor(
@@ -210,9 +206,7 @@ class Pointclouds(object):
self._N = self._points_padded.shape[0]
self._P = self._points_padded.shape[1]
self.device = self._points_padded.device
-self.valid = torch.ones(
-    (self._N,), dtype=torch.bool, device=self.device
-)
+self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
self._num_points_per_cloud = torch.tensor(
[self._P] * self._N, device=self.device
)
@@ -260,9 +254,7 @@ class Pointclouds(object):
if isinstance(aux_input, list):
if len(aux_input) != self._N:
-raise ValueError(
-    "Points and auxiliary input must be the same length."
-)
+raise ValueError("Points and auxiliary input must be the same length.")
for p, d in zip(self._num_points_per_cloud, aux_input):
if p != d.shape[0]:
raise ValueError(
@@ -282,9 +274,7 @@ class Pointclouds(object):
return aux_input, None, aux_input_C
elif torch.is_tensor(aux_input):
if aux_input.dim() != 3:
-raise ValueError(
-    "Auxiliary input tensor has incorrect dimensions."
-)
+raise ValueError("Auxiliary input tensor has incorrect dimensions.")
if self._N != aux_input.shape[0]:
raise ValueError("Points and inputs must be the same length.")
if self._P != aux_input.shape[1]:
@@ -531,8 +521,7 @@ class Pointclouds(object):
else:
self._padded_to_packed_idx = torch.cat(
[
-torch.arange(v, dtype=torch.int64, device=self.device)
-+ i * self._P
+torch.arange(v, dtype=torch.int64, device=self.device) + i * self._P
for (i, v) in enumerate(self._num_points_per_cloud)
],
dim=0,
@@ -551,9 +540,7 @@ class Pointclouds(object):
self._normals_padded, self._features_padded = None, None
if self.isempty():
-self._points_padded = torch.zeros(
-    (self._N, 0, 3), device=self.device
-)
+self._points_padded = torch.zeros((self._N, 0, 3), device=self.device)
else:
self._points_padded = struct_utils.list_to_padded(
self.points_list(),
@@ -621,9 +608,7 @@ class Pointclouds(object):
points_list_to_packed = struct_utils.list_to_packed(points_list)
self._points_packed = points_list_to_packed[0]
-if not torch.allclose(
-    self._num_points_per_cloud, points_list_to_packed[1]
-):
+if not torch.allclose(self._num_points_per_cloud, points_list_to_packed[1]):
raise ValueError("Inconsistent list to packed conversion")
self._cloud_to_packed_first_idx = points_list_to_packed[2]
self._packed_to_cloud_idx = points_list_to_packed[3]
@@ -696,13 +681,9 @@ class Pointclouds(object):
if other._N > 0:
other._points_list = [v.to(device) for v in other.points_list()]
if other._normals_list is not None:
-other._normals_list = [
-    n.to(device) for n in other.normals_list()
-]
+other._normals_list = [n.to(device) for n in other.normals_list()]
if other._features_list is not None:
-other._features_list = [
-    f.to(device) for f in other.features_list()
-]
+other._features_list = [f.to(device) for f in other.features_list()]
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
@@ -892,16 +873,11 @@ class Pointclouds(object):
for features in self.features_list():
new_features_list.extend(features.clone() for _ in range(N))
return Pointclouds(
-    points=new_points_list,
-    normals=new_normals_list,
-    features=new_features_list,
+    points=new_points_list, normals=new_normals_list, features=new_features_list
)
def update_padded(
-    self,
-    new_points_padded,
-    new_normals_padded=None,
-    new_features_padded=None,
+    self, new_points_padded, new_normals_padded=None, new_features_padded=None
):
"""
Returns a Pointcloud structure with updated padded tensors and copies of
@@ -920,13 +896,9 @@ class Pointclouds(object):
def check_shapes(x, size):
if x.shape[0] != size[0]:
-raise ValueError(
-    "new values must have the same batch dimension."
-)
+raise ValueError("new values must have the same batch dimension.")
if x.shape[1] != size[1]:
-raise ValueError(
-    "new values must have the same number of points."
-)
+raise ValueError("new values must have the same number of points.")
if size[2] is not None:
if x.shape[2] != size[2]:
raise ValueError(
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Optional, Union
import torch
import torchvision.transforms as T
@@ -233,11 +234,7 @@ class Textures(object):
if all(
v is not None
-for v in [
-    self._faces_uvs_padded,
-    self._verts_uvs_padded,
-    self._maps_padded,
-]
+for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
):
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Union
import torch
@@ -38,9 +39,7 @@ def list_to_padded(
pad_dim1 = max(y.shape[1] for y in x if len(y) > 0)
else:
if len(pad_size) != 2:
-raise ValueError(
-    "Pad size must contain target size for 1st and 2nd dim"
-)
+raise ValueError("Pad size must contain target size for 1st and 2nd dim")
pad_dim0, pad_dim1 = pad_size
N = len(x)
@@ -55,9 +54,7 @@ def list_to_padded(
return x_padded
-def padded_to_list(
-    x: torch.Tensor, split_size: Union[list, tuple, None] = None
-):
+def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None):
r"""
Transforms a padded tensor of shape (N, M, K) into a list of N tensors
of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape
@@ -81,9 +78,7 @@ def padded_to_list(
N = len(split_size)
if x.shape[0] != N:
-raise ValueError(
-    "Split size must be of same length as inputs first dimension"
-)
+raise ValueError("Split size must be of same length as inputs first dimension")
for i in range(N):
if isinstance(split_size[i], int):
@@ -119,9 +114,7 @@ def list_to_packed(x: List[torch.Tensor]):
"""
N = len(x)
num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device)
-item_packed_first_idx = torch.zeros(
-    N, dtype=torch.int64, device=x[0].device
-)
+item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device)
item_packed_to_list_idx = []
cur = 0
for i, y in enumerate(x):
@@ -187,9 +180,7 @@ def padded_to_packed(
N, M, D = x.shape
if split_size is not None and pad_value is not None:
-raise ValueError(
-    "Only one of split_size or pad_value should be provided."
-)
+raise ValueError("Only one of split_size or pad_value should be provided.")
x_packed = x.reshape(-1, D) # flatten padded
@@ -205,9 +196,7 @@ def padded_to_packed(
# Convert to packed using split sizes
N = len(split_size)
if x.shape[0] != N:
-raise ValueError(
-    "Split size must be of same length as inputs first dimension"
-)
+raise ValueError("Split size must be of same length as inputs first dimension")
if not all(isinstance(i, int) for i in split_size):
raise ValueError(
......
@@ -22,4 +22,5 @@ from .so3 import (
)
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
__all__ = [k for k in globals().keys() if not k.startswith("_")]
@@ -2,6 +2,7 @@
import functools
from typing import Optional
import torch
@@ -155,9 +156,7 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
-matrices = map(
-    _axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
-)
+matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
@@ -246,10 +245,7 @@ def matrix_to_euler_angles(matrix, convention: str):
def random_quaternions(
-    n: int,
-    dtype: Optional[torch.dtype] = None,
-    device=None,
-    requires_grad=False,
+    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random quaternions representing rotations,
@@ -266,19 +262,14 @@
Returns:
Quaternions as tensor of shape (N, 4).
"""
-o = torch.randn(
-    (n, 4), dtype=dtype, device=device, requires_grad=requires_grad
-)
+o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o
def random_rotations(
-    n: int,
-    dtype: Optional[torch.dtype] = None,
-    device=None,
-    requires_grad=False,
+    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random rotations as 3x3 rotation matrices.
......
@@ -3,6 +3,7 @@
import torch
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
@@ -65,9 +66,7 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
-raise ValueError(
-    "A matrix has trace outside valid range [-1-eps,3+eps]."
-)
+raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
# clamp to valid range
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
......
@@ -3,6 +3,7 @@
import math
import warnings
from typing import Optional
import torch
from .rotation_conversions import _axis_angle_rotation
@@ -230,9 +231,7 @@ class Transform3d:
# the transformations with get_matrix(), this correctly
# right-multiplies by the inverse of self._matrix
# at the end of the composition.
-tinv._transforms = [
-    t.inverse() for t in reversed(self._transforms)
-]
+tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
last = Transform3d(device=self.device)
last._matrix = i_matrix
tinv._transforms.append(last)
@@ -334,9 +333,7 @@ class Transform3d:
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
-return self.compose(
-    RotateAxisAngle(device=self.device, *args, **kwargs)
-)
+return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
def clone(self):
"""
@@ -388,9 +385,7 @@ class Transform3d:
class Translate(Transform3d):
-def __init__(
-    self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
-):
+def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
"""
Create a new Transform3d representing 3D translations.
@@ -424,9 +419,7 @@ class Translate(Transform3d):
class Scale(Transform3d):
-def __init__(
-    self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
-):
+def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@@ -444,9 +437,7 @@ class Scale(Transform3d):
- 1D torch tensor
"""
super().__init__(device=device)
-xyz = _handle_input(
-    x, y, z, dtype, device, "scale", allow_singleton=True
-)
+xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
N = xyz.shape[0]
# TODO: Can we do this all in one go somehow?
@@ -469,11 +460,7 @@ class Scale(Transform3d):
class Rotate(Transform3d):
def __init__(
-    self,
-    R,
-    dtype=torch.float32,
-    device: str = "cpu",
-    orthogonal_tol: float = 1e-5,
+    self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
):
"""
Create a new Transform3d representing 3D rotation using a rotation
@@ -562,9 +549,7 @@ def _handle_coord(c, dtype, device):
return c
-def _handle_input(
-    x, y, z, dtype, device, name: str, allow_singleton: bool = False
-):
+def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
......
@@ -3,4 +3,5 @@
from .ico_sphere import ico_sphere
from .torus import torus
__all__ = [k for k in globals().keys() if not k.startswith("_")]
@@ -2,10 +2,10 @@
import torch
from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes
# Vertex coordinates for a level 0 ico-sphere.
_ico_verts0 = [
[-0.5257, 0.8507, 0.0000],
......
@@ -3,8 +3,8 @@
from itertools import tee
from math import cos, pi, sin
from typing import Iterator, Optional, Tuple
-import torch
+import torch
from pytorch3d.structures.meshes import Meshes
@@ -16,11 +16,7 @@ def _make_pair_range(N: int) -> Iterator[Tuple[int, int]]:
def torus(
-    r: float,
-    R: float,
-    sides: int,
-    rings: int,
-    device: Optional[torch.device] = None,
+    r: float, R: float, sides: int, rings: int, device: Optional[torch.device] = None
) -> Meshes:
"""
Create vertices and faces for a torus.
......
@@ -4,10 +4,12 @@
import argparse
import json
import os
import nbformat
from bs4 import BeautifulSoup
from nbconvert import HTMLExporter, ScriptExporter
TEMPLATE = """const CWD = process.cwd();
const React = require('react');
@@ -41,9 +43,7 @@ def gen_tutorials(repo_dir: str) -> None:
Also create ipynb and py versions of tutorial in Docusaurus site for
download.
"""
-with open(
-    os.path.join(repo_dir, "website", "tutorials.json"), "r"
-) as infile:
+with open(os.path.join(repo_dir, "website", "tutorials.json"), "r") as infile:
tutorial_config = json.loads(infile.read())
tutorial_ids = {x["id"] for v in tutorial_config.values() for x in v}
@@ -107,10 +107,7 @@ if __name__ == "__main__":
description="Generate JS, HTML, ipynb, and py files for tutorials."
)
parser.add_argument(
"--repo_dir",
metavar="path",
required=True,
help="PyTorch3D repo directory.",
"--repo_dir", metavar="path", required=True, help="PyTorch3D repo directory."
)
args = parser.parse_args()
gen_tutorials(args.repo_dir)
@@ -3,8 +3,9 @@
import glob
import os
-from setuptools import find_packages, setup
import torch
+from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
......
@@ -2,8 +2,8 @@
from itertools import product
-from fvcore.common.benchmark import benchmark
+from fvcore.common.benchmark import benchmark
from test_blending import TestBlending
@@ -18,12 +18,7 @@ def bm_blending() -> None:
for case in test_cases:
n, s, k, d = case
kwargs_list.append(
-    {
-        "num_meshes": n,
-        "image_size": s,
-        "faces_per_pixel": k,
-        "device": d,
-    }
+    {"num_meshes": n, "image_size": s, "faces_per_pixel": k, "device": d}
)
benchmark(
......
@@ -3,7 +3,6 @@
import torch
from fvcore.common.benchmark import benchmark
from test_chamfer import TestChamfer
@@ -25,9 +24,4 @@ def bm_chamfer() -> None:
{"batch_size": 1, "P1": 1000, "P2": 3000, "return_normals": False},
{"batch_size": 1, "P1": 1000, "P2": 30000, "return_normals": True},
]
-benchmark(
-    TestChamfer.chamfer_with_init,
-    "CHAMFER",
-    kwargs_list,
-    warmup_iters=1,
-)
+benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from fvcore.common.benchmark import benchmark
from test_cubify import TestCubify
@@ -11,6 +10,4 @@ def bm_cubify() -> None:
{"batch_size": 64, "V": 16},
{"batch_size": 16, "V": 32},
]
-benchmark(
-    TestCubify.cubify_with_init, "CUBIFY", kwargs_list, warmup_iters=1
-)
+benchmark(TestCubify.cubify_with_init, "CUBIFY", kwargs_list, warmup_iters=1)
@@ -2,9 +2,9 @@
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from test_face_areas_normals import TestFaceAreasNormals
......
@@ -2,9 +2,9 @@
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from test_graph_conv import TestGraphConv
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d import _C
from pytorch3d.ops.knn import _knn_points_idx_naive
@@ -32,9 +32,7 @@ def benchmark_knn_cuda_versions() -> None:
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K, "v": version})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({"N": N, "D": D, "P": P})
-benchmark(
-    knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1
-)
+benchmark(knn_cuda_with_init, "KNN_CUDA_VERSIONS", knn_kwargs, warmup_iters=1)
benchmark(nn_cuda_with_init, "NN_CUDA", nn_kwargs, warmup_iters=1)
@@ -50,10 +48,7 @@ def benchmark_knn_cuda_vs_naive() -> None:
if P <= 4096:
naive_kwargs.append({"N": N, "D": D, "P": P, "K": K})
benchmark(
-    knn_python_cuda_with_init,
-    "KNN_CUDA_PYTHON",
-    naive_kwargs,
-    warmup_iters=1,
+    knn_python_cuda_with_init, "KNN_CUDA_PYTHON", naive_kwargs, warmup_iters=1
)
benchmark(knn_cuda_with_init, "KNN_CUDA", knn_kwargs, warmup_iters=1)
@@ -68,9 +63,7 @@ def benchmark_knn_cpu() -> None:
knn_kwargs.append({"N": N, "D": D, "P": P, "K": K})
for N, P, D in product(Ns, Ps, Ds):
nn_kwargs.append({"N": N, "D": D, "P": P})
-benchmark(
-    knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1
-)
+benchmark(knn_python_cpu_with_init, "KNN_CPU_PYTHON", knn_kwargs, warmup_iters=1)
benchmark(knn_cpu_with_init, "KNN_CPU_CPP", knn_kwargs, warmup_iters=1)
benchmark(nn_cpu_with_init, "NN_CPU_CPP", nn_kwargs, warmup_iters=1)
......