# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Union

import torch
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding

from .autodecoder import Autodecoder


class GlobalEncoderBase(ReplaceableBase):
    """
    Abstract base for encoders of global, frame-specific quantities.

    Concrete implementations include the harmonic encoding of a frame's
    timestamp (`HarmonicTimeEncoder`) and an autodecoder encoding of the
    sequence a frame belongs to (`SequenceAutodecoder`).

    Subclasses override `forward`, `get_encoding_dim` and
    `calc_squared_encoding_norm`; the versions defined here only raise
    `NotImplementedError`.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Encode the given inputs into a global encoding tensor.

        Args:
            **kwargs: The inputs to encode; which keyword arguments are
                consumed is decided by each subclass.

        Returns:
            encoding: The tensor holding the global encoding.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_encoding_dim(self):
        """
        Report the dimensionality of the encoding produced by `forward`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def calc_squared_encoding_norm(self):
        """
        Compute the squared norm of the encoding.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 13
    """
    Global encoder that encodes each frame's sequence identifier with an
    `Autodecoder` sub-module; every method delegates to that sub-module.
    """

    # Not assigned explicitly in this class; expected to be populated by
    # `run_auto_creation` in `__post_init__` (hence the pyre-ignore above).
    autodecoder: Autodecoder

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

    def get_encoding_dim(self):
        """Return the encoding dimensionality reported by the autodecoder."""
        return self.autodecoder.get_encoding_dim()

    def calc_squared_encoding_norm(self):
        """Return the squared encoding norm computed by the autodecoder."""
        return self.autodecoder.calc_squared_encoding_norm()

    def forward(
        self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
    ) -> torch.Tensor:
        """
        Encode the given sequence identifiers.

        Args:
            sequence_name: Sequence identifiers, given either as a list of
                strings or as a LongTensor. They are passed unchanged to
                `self.autodecoder`; this method performs no dtype checks
                itself.
            **kwargs: Ignored here (presumably accepted so callers can pass
                the same keyword arguments to any global encoder).

        Returns:
            The tensor returned by `self.autodecoder`.
        """
        encoding = self.autodecoder(sequence_name)
        return encoding


@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Members:
        n_harmonic_functions: Forwarded to the underlying `HarmonicEmbedding`
            as its number of harmonic functions.
        append_input: Forwarded to the underlying `HarmonicEmbedding`
            (presumably controls whether the raw input is concatenated to the
            harmonic features — see `HarmonicEmbedding`).
        time_divisor: Timestamps are divided by this value before embedding.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        super().__init__()
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        """
        Returns the output dimensionality of the harmonic embedding for a
        1-dimensional input, matching the size-1 last dimension that
        `forward` requires of `frame_timestamp`.
        """
        return self._harmonic_embedding.get_output_dim(1)

    def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Embed frame timestamps harmonically.

        Args:
            frame_timestamp: Tensor of timestamps whose last dimension has
                size 1.
            **kwargs: Ignored.

        Returns:
            The output of the harmonic embedding applied to
            `frame_timestamp / self.time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is 0-dimensional or its last
                dimension is not of size 1.
        """
        # Check the rank first: `shape[-1]` on a 0-dim tensor would raise an
        # IndexError rather than the intended ValueError.
        if frame_timestamp.dim() == 0 or frame_timestamp.shape[-1] != 1:
            raise ValueError("Frame timestamp's last dimensions should be one.")
        time = frame_timestamp / self.time_divisor
        return self._harmonic_embedding(time)  # pyre-ignore: 29

    def calc_squared_encoding_norm(self):
        """
        Always returns 0.0: unlike an autodecoder-based encoder, there is no
        per-sequence code whose norm would need to be reported.
        """
        return 0.0
