Commit 2c64635d authored by Krzysztof Chalupka's avatar Krzysztof Chalupka Committed by Facebook GitHub Bot
Browse files

Add type hints to MeshRenderer(WithFragments)

Reviewed By: bottler

Differential Revision: D36148049

fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00
parent ec9580a1
...@@ -4,10 +4,14 @@ ...@@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...structures.meshes import Meshes
from .rasterizer import MeshRasterizer
# A renderer class should be initialized with a # A renderer class should be initialized with a
# function for rasterization and a function for shading. # function for rasterization and a function for shading.
...@@ -32,7 +36,7 @@ class MeshRenderer(nn.Module): ...@@ -32,7 +36,7 @@ class MeshRenderer(nn.Module):
function. function.
""" """
def __init__(self, rasterizer, shader) -> None: def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader
...@@ -43,7 +47,7 @@ class MeshRenderer(nn.Module): ...@@ -43,7 +47,7 @@ class MeshRenderer(nn.Module):
self.shader.to(device) self.shader.to(device)
return self return self
def forward(self, meshes_world, **kwargs) -> torch.Tensor: def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor:
""" """
Render a batch of images from a batch of meshes by rasterizing and then Render a batch of images from a batch of meshes by rasterizing and then
shading. shading.
...@@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module): ...@@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf depth = fragments.zbuf
""" """
def __init__(self, rasterizer, shader) -> None: def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader
...@@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module): ...@@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module):
# Rasterizer and shader have submodules which are not of type nn.Module # Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device) self.rasterizer.to(device)
self.shader.to(device) self.shader.to(device)
return self
def forward(self, meshes_world, **kwargs): def forward(
self, meshes_world: Meshes, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Render a batch of images from a batch of meshes by rasterizing and then Render a batch of images from a batch of meshes by rasterizing and then
shading. shading.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment