Commit 2a1de3b6 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

move LinearWithRepeat to pytorch3d

Summary: Move this simple layer from the NeRF project into pytorch3d.

Reviewed By: shapovalov

Differential Revision: D34126972

fbshipit-source-id: a9c6d6c3c1b662c1b844ea5d1b982007d4df83e6
parent ef21a6f6
...@@ -7,10 +7,9 @@ ...@@ -7,10 +7,9 @@
from typing import Tuple from typing import Tuple
import torch import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points
from .linear_with_repeat import LinearWithRepeat
def _xavier_init(linear): def _xavier_init(linear):
""" """
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math
from typing import Tuple from typing import Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Parameter, init
class LinearWithRepeat(torch.nn.Linear): class LinearWithRepeat(torch.nn.Module):
""" """
if x has shape (..., k, n1) if x has shape (..., k, n1)
and y has shape (..., n2) and y has shape (..., n2)
...@@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear): ...@@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear):
and sent that through the Linear. and sent that through the Linear.
""" """
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
"""
Copied from torch.nn.Linear.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(
torch.empty((out_features, in_features), **factory_kwargs)
)
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
"""
Copied from torch.nn.Linear.
"""
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
n1 = input[0].shape[-1] n1 = input[0].shape[-1]
output1 = F.linear(input[0], self.weight[:, :n1], self.bias) output1 = F.linear(input[0], self.weight[:, :n1], self.bias)
......
...@@ -73,8 +73,8 @@ from .points import ( ...@@ -73,8 +73,8 @@ from .points import (
from .utils import ( from .utils import (
TensorProperties, TensorProperties,
convert_to_tensors_and_broadcast, convert_to_tensors_and_broadcast,
ndc_to_grid_sample_coords,
ndc_grid_sample, ndc_grid_sample,
ndc_to_grid_sample_coords,
) )
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from typing import Any, Optional, Union, Tuple from typing import Any, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def test_simple(self):
x = torch.rand(4, 6, 7, 3)
y = torch.rand(4, 6, 4)
linear = torch.nn.Linear(7, 8)
torch.nn.init.xavier_uniform_(linear.weight.data)
linear.bias.data.uniform_()
equivalent = torch.cat([x, y.unsqueeze(-2).expand(4, 6, 7, 4)], dim=-1)
expected = linear.forward(equivalent)
linear_with_repeat = LinearWithRepeat(7, 8)
linear_with_repeat.load_state_dict(linear.state_dict())
actual = linear_with_repeat.forward((x, y))
self.assertClose(actual, expected, rtol=1e-4)
...@@ -12,16 +12,16 @@ import torch ...@@ -12,16 +12,16 @@ import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.ops import eyes from pytorch3d.ops import eyes
from pytorch3d.renderer import ( from pytorch3d.renderer import (
PerspectiveCameras,
AlphaCompositor, AlphaCompositor,
PointsRenderer, PerspectiveCameras,
PointsRasterizationSettings, PointsRasterizationSettings,
PointsRasterizer, PointsRasterizer,
PointsRenderer,
) )
from pytorch3d.renderer.utils import ( from pytorch3d.renderer.utils import (
TensorProperties, TensorProperties,
ndc_to_grid_sample_coords,
ndc_grid_sample, ndc_grid_sample,
ndc_to_grid_sample_coords,
) )
from pytorch3d.structures import Pointclouds from pytorch3d.structures import Pointclouds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment