from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    Output of [`UNet1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            Hidden states output. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int`, *optional*, defaults to 65536):
            Default length of sample. Should be adaptable at runtime.
        sample_rate (`int`, *optional*):
            Sample rate of the data. Only recorded in the config; not used by the model's layers.
        in_channels (`int`, *optional*, defaults to 2):
            Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2):
            Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Extra channels added to the input width of the *first* down block only. Presumably this accounts for
            the time embedding that some block types concatenate to their input; check the chosen block type.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of time embedding to use: `"fourier"` or `"positional"`.
        freq_shift (`int`, *optional*, defaults to 0):
            Frequency shift for the `"positional"` time embedding (ignored for `"fourier"`).
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos in the time embedding.
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            Whether to build an additional `TimestepEmbedding` module on top of the time projection.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`):
            Type of the middle block.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        mid_block_type: str = "UNetMidBlock1D",
        up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        block_out_channels: Tuple[int, ...] = (32, 32, 64),
    ):
        """Build the time projection, down/mid/up blocks from the config values.

        Raises:
            ValueError: if `time_embedding_type` is neither `"fourier"` nor `"positional"`.
        """
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            # NOTE(review): this dim is only consumed by `TimestepEmbedding` below. It may not match the width
            # the fourier projection (embedding_size=8) actually emits; confirm before enabling
            # `use_timestep_embedding`. Left unchanged to keep parameter shapes stable.
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            # Previously an unknown value left `time_proj` undefined and only failed later, in `forward`.
            raise ValueError(
                f"Unsupported time_embedding_type {time_embedding_type!r}; expected 'fourier' or 'positional'."
            )

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            # NOTE(review): this module is constructed but `forward` does not apply it.
            self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down: each block maps the previous block's width to block_out_channels[i].
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            down_block = get_down_block(
                down_block_type,
                in_channels=input_channel,
                out_channels=output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type=mid_block_type,
            mid_channels=block_out_channels[-1],
            in_channels=block_out_channels[-1],
            out_channels=None,
        )

        # up: walk the channel widths in reverse; the last up block emits `out_channels`.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            is_last_block = i == len(up_block_types) - 1
            output_channel = out_channels if is_last_block else reversed_block_out_channels[i + 1]

            up_block = get_up_block(
                up_block_type,
                in_channels=prev_output_channel,
                out_channels=output_channel,
            )
            self.up_blocks.append(up_block)

        # TODO(PVP, Nathan) placeholder for RL application to be merged shortly
        # Totally fine to add another layer with a if statement - no need for nn.Identity here

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
            timestep (`torch.Tensor` or `float` or `int`): (batch) timesteps. A scalar (Python number or 0-d
                tensor) is shared by every element of the batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Python scalars have no `.shape`: wrap them in a 1-element tensor on the sample's device.
            # dtype is inferred (int64 for int, default float dtype for float) so fractional values survive.
            timesteps = torch.tensor([timesteps], device=sample.device)
        else:
            if timesteps.dim() == 0:
                timesteps = timesteps[None]
            timesteps = timesteps.to(sample.device)

        timestep_embed = self.time_proj(timesteps)[..., None]
        # Tile along the length axis and match the sample dtype (e.g. fp16 inference).
        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
        # A single timestep must cover the whole batch; expand is a no-op when the batch sizes already match.
        timestep_embed = timestep_embed.expand(sample.shape[0], -1, -1)

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        sample = self.mid_block(sample)

        # 4. up: consume the skip connections last-in, first-out.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_samples)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)