Commit 791a0681 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

avoid math.prod for python 3.7

Summary: This makes the new volumes tutorial work on google colab.

Reviewed By: kjchalup

Differential Revision: D38501906

fbshipit-source-id: a606a357e929dae903dc4d9067bd1519f05b1458
parent c49ebad2
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
""" """
Some functions which depend on PyTorch versions. Some functions which depend on PyTorch or Python versions.
""" """
...@@ -79,3 +79,12 @@ def meshgrid_ij( ...@@ -79,3 +79,12 @@ def meshgrid_ij(
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got # pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
# `Union[Sequence[Tensor], Tensor]`. # `Union[Sequence[Tensor], Tensor]`.
return torch.meshgrid(*A) return torch.meshgrid(*A)
def prod(iterable, *, start=1):
"""
Like math.prod in Python 3.8 and later.
"""
for i in iterable:
start *= i
return start
...@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import tqdm import tqdm
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.metrics import ( from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetricsBase, ViewMetricsBase,
...@@ -919,7 +920,7 @@ def _chunk_generator( ...@@ -919,7 +920,7 @@ def _chunk_generator(
f"by n_pts_per_ray ({n_pts_per_ray})" f"by n_pts_per_ray ({n_pts_per_ray})"
) )
n_rays = math.prod(spatial_dim) n_rays = prod(spatial_dim)
# special handling for raytracing-based methods # special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size) n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks) chunk_size_in_rays = -(-n_rays // n_chunks)
...@@ -935,9 +936,9 @@ def _chunk_generator( ...@@ -935,9 +936,9 @@ def _chunk_generator(
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx :, start_idx:end_idx
], ],
lengths=ray_bundle.lengths.reshape( lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
batch_size, math.prod(spatial_dim), n_pts_per_ray :, start_idx:end_idx
)[:, start_idx:end_idx], ],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
) )
extra_args = kwargs.copy() extra_args = kwargs.copy()
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from pytorch3d.common.compat import prod
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
...@@ -52,7 +52,7 @@ def create_embeddings_for_implicit_function( ...@@ -52,7 +52,7 @@ def create_embeddings_for_implicit_function(
embeds = torch.empty( embeds = torch.empty(
bs, bs,
1, 1,
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
0, 0,
dtype=xyz_world.dtype, dtype=xyz_world.dtype,
...@@ -62,7 +62,7 @@ def create_embeddings_for_implicit_function( ...@@ -62,7 +62,7 @@ def create_embeddings_for_implicit_function(
embeds = xyz_embedding_function(ray_points_for_embed).reshape( embeds = xyz_embedding_function(ray_points_for_embed).reshape(
bs, bs,
1, 1,
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
-1, -1,
) # flatten spatial, add n_src dim ) # flatten spatial, add n_src dim
...@@ -73,7 +73,7 @@ def create_embeddings_for_implicit_function( ...@@ -73,7 +73,7 @@ def create_embeddings_for_implicit_function(
embed_shape = ( embed_shape = (
bs, bs,
embeds_viewpooled.shape[1], embeds_viewpooled.shape[1],
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
-1, -1,
) )
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
# implicit_differentiable_renderer.py # implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv # Copyright (c) 2020 Lior Yariv
import functools import functools
import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
get_default_args_field, get_default_args_field,
registry, registry,
...@@ -105,7 +105,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign ...@@ -105,7 +105,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
# object_mask: silhouette of the object # object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape batch_size, *spatial_size, _ = ray_bundle.lengths.shape
num_pixels = math.prod(spatial_size) num_pixels = prod(spatial_size)
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3) cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3) ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment