base.py 837 Bytes
Newer Older
Ruilong Li's avatar
Ruilong Li committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from abc import abstractmethod
from typing import Optional, Tuple

import torch
import torch.nn as nn


class BaseRadianceField(nn.Module):
    """An abstract radiance field base class (supports both 2D and 3D).

    Subclasses are expected to implement:

    - ``forward(positions, directions, masks)``: returns an ``(rgb, density)``
      tuple of tensors.
    - ``query_density(positions)``: returns a density tensor.

    Note:
        ``nn.Module`` does not use ``abc.ABCMeta`` as its metaclass, so the
        ``@abstractmethod`` markers below are NOT enforced when the class is
        instantiated. Calling a method that a subclass did not override
        raises ``NotImplementedError`` instead.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Arbitrary positional/keyword arguments are accepted but ignored here;
        # only the nn.Module machinery is initialised.
        super().__init__()

    @abstractmethod
    def forward(
        self,
        positions: torch.Tensor,
        directions: Optional[torch.Tensor] = None,
        masks: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the field at ``positions``.

        Tensor shapes are defined by the concrete subclass.

        Args:
            positions: Query positions.
            directions: Optional viewing directions; ``None`` when not used.
            masks: Optional mask tensor; ``None`` when not used.

        Returns:
            A ``(rgb, density)`` tuple of tensors.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abstractmethod
    def query_density(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the density at ``positions``.

        Args:
            positions: Query positions (shape defined by the subclass).

        Returns:
            A density tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()