Commit 3b7ab22d authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

MeshRasterizerOpenGL import fixes

Summary: Only import it if you ask for it.

Reviewed By: kjchalup

Differential Revision: D38327167

fbshipit-source-id: 3f05231f26eda582a63afc71b669996342b0c6f9
parent 5bf6d532
......@@ -7,7 +7,6 @@
import contextlib
import logging
import re
from typing import List
@contextlib.contextmanager
......
......@@ -65,11 +65,6 @@ from .mesh import (
TexturesVertex,
)
try:
from .opengl import EGLContext, global_device_context_store, MeshRasterizerOpenGL
except (ImportError, ModuleNotFoundError):
pass # opengl or pycuda.gl not available, or pytorch3_opengl not in TARGETS.
from .points import (
AlphaCompositor,
NormWeightedCompositor,
......
......@@ -10,8 +10,6 @@ import torch
import torch.nn as nn
from ...structures.meshes import Meshes
from .rasterizer import MeshRasterizer
# A renderer class should be initialized with a
# function for rasterization and a function for shading.
......@@ -32,11 +30,11 @@ from .rasterizer import MeshRasterizer
class MeshRenderer(nn.Module):
"""
A class for rendering a batch of heterogeneous meshes. The class should
be initialized with a rasterizer and shader class which each have a forward
function.
be initialized with a rasterizer (a MeshRasterizer or a MeshRasterizerOpenGL)
and shader class which each have a forward function.
"""
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
def __init__(self, rasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
......@@ -69,8 +67,8 @@ class MeshRenderer(nn.Module):
class MeshRendererWithFragments(nn.Module):
"""
A class for rendering a batch of heterogeneous meshes. The class should
be initialized with a rasterizer and shader class which each have a forward
function.
be initialized with a rasterizer (a MeshRasterizer or a MeshRasterizerOpenGL)
and shader class which each have a forward function.
In the forward pass this class returns the `fragments` from which intermediate
values such as the depth map can be easily extracted e.g.
......@@ -80,7 +78,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf
"""
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
def __init__(self, rasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
......
......@@ -130,8 +130,9 @@ class MeshRasterizerOpenGL(nn.Module):
Fragments output by MeshRasterizerOpenGL and MeshRasterizer should have near
identical pix_to_face, bary_coords and zbuf. However, MeshRasterizerOpenGL does not
return Fragments.dists which is only relevant to SoftPhongShader which doesn't work
with MeshRasterizerOpenGL (because it is not differentiable).
return Fragments.dists which is only relevant to SoftPhongShader and
SoftSilhouetteShader. These do not work with MeshRasterizerOpenGL (because it is
not differentiable).
"""
def __init__(
......
......@@ -17,7 +17,6 @@ from pytorch3d.renderer import (
look_at_view_transform,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
PointLights,
RasterizationSettings,
......@@ -30,6 +29,7 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes_python,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.opengl import MeshRasterizerOpenGL
from pytorch3d.renderer.points import (
AlphaCompositor,
PointsRasterizationSettings,
......
......@@ -15,7 +15,6 @@ from pytorch3d.renderer import (
FoVPerspectiveCameras,
look_at_view_transform,
MeshRasterizer,
MeshRasterizerOpenGL,
OrthographicCameras,
PerspectiveCameras,
PointsRasterizationSettings,
......@@ -27,6 +26,7 @@ from pytorch3d.renderer.opengl.rasterizer_opengl import (
_check_raster_settings,
_convert_meshes_to_gl_ndc,
_parse_and_verify_image_size,
MeshRasterizerOpenGL,
)
from pytorch3d.structures import Pointclouds
from pytorch3d.structures.meshes import Meshes
......
......@@ -23,7 +23,6 @@ from pytorch3d.renderer import (
look_at_view_transform,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
MeshRendererWithFragments,
OrthographicCameras,
......@@ -44,6 +43,7 @@ from pytorch3d.renderer.mesh.shader import (
SplatterPhongShader,
TexturedSoftPhongShader,
)
from pytorch3d.renderer.opengl import MeshRasterizerOpenGL
from pytorch3d.structures.meshes import (
join_meshes_as_batch,
join_meshes_as_scene,
......
......@@ -14,7 +14,6 @@ from pytorch3d.renderer import (
HardGouraudShader,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
PointLights,
PointsRasterizationSettings,
......@@ -26,6 +25,7 @@ from pytorch3d.renderer import (
TexturesVertex,
)
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.opengl import MeshRasterizerOpenGL
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment