"examples/vscode:/vscode.git/clone" did not exist on "8d4bb020565e404d9eb814150280147e4963a2ee"
Commit 48c9cec6 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2419 failed with stages
in 0 seconds
Metadata-Version: 2.2
Name: direct3d
Version: 1.0.0
Summary: Direct3D: Scalable Image-to-3D Generation via 3D Latent Diffusion Transformer
Requires-Python: >=3.10
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Requires-Dist: cython
Requires-Dist: trimesh
Requires-Dist: diffusers
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary
LICENSE
README.md
setup.py
direct3d.egg-info/PKG-INFO
direct3d.egg-info/SOURCES.txt
direct3d.egg-info/dependency_links.txt
direct3d.egg-info/requires.txt
direct3d.egg-info/top_level.txt
torch
numpy
cython
trimesh
diffusers
import torch.nn as nn
from transformers import CLIPModel, AutoModel
from torchvision import transforms as T
class ClipImageEncoder(nn.Module):
def __init__(self, version="openai/clip-vit-large-patch14", img_size=224):
super().__init__()
encoder = CLIPModel.from_pretrained(version)
encoder = encoder.eval()
self.encoder = encoder
self.transform = T.Compose(
[
T.Resize(img_size, antialias=True),
T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
def forward(self, image):
image = self.transform(image)
        embed = self.encoder.vision_model(image).last_hidden_state
        return embed
class DinoEncoder(nn.Module):
def __init__(self, version="facebook/dinov2-large", img_size=224):
super().__init__()
encoder = AutoModel.from_pretrained(version)
encoder = encoder.eval()
self.encoder = encoder
self.transform = T.Compose(
[
T.Resize(img_size, antialias=True),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
def forward(self, image):
image = self.transform(image)
        embed = self.encoder(image).last_hidden_state
        return embed
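# Illustrative usage sketch (not part of the original module); the token count
# and feature width noted below assume the default checkpoints named above.
def _example_image_encoders():
    import torch
    clip_encoder = ClipImageEncoder()      # downloads openai/clip-vit-large-patch14
    dino_encoder = DinoEncoder()           # downloads facebook/dinov2-large
    image = torch.rand(1, 3, 224, 224)     # dummy RGB batch scaled to [0, 1]
    with torch.no_grad():
        clip_tokens = clip_encoder(image)  # (1, 257, 1024) for ViT-L/14 at 224px
        dino_tokens = dino_encoder(image)  # (1, 257, 1024) for DINOv2-large at 224px
    return clip_tokens.shape, dino_tokens.shape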
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.models.embeddings import PixArtAlphaTextProjection, get_2d_sincos_pos_embed_from_grid
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.attention import FeedForward
class ClassCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim, class_emb_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256,
time_embed_dim=embedding_dim,
cond_proj_dim=class_emb_dim)
def forward(self, timestep, hidden_dtype, class_embedding=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype),
condition=class_embedding) # (N, D)
return timesteps_emb
class AdaLayerNormClassEmb(nn.Module):
def __init__(self, embedding_dim: int, class_emb_dim: int):
super().__init__()
self.emb = ClassCombinedTimestepSizeEmbeddings(
embedding_dim, class_emb_dim
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(
self,
timestep: torch.Tensor,
class_embedding: torch.Tensor = None,
hidden_dtype: Optional[torch.dtype] = None,
):
embedded_timestep = self.emb(timestep,
class_embedding=class_embedding,
hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
if isinstance(base_size, int):
base_size = (base_size, base_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
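# Quick sanity check (illustrative, arbitrary sizes): the helper returns one
# sin/cos vector per grid cell, i.e. an array of shape (grid_h * grid_w, embed_dim).
def _example_pos_embed():
    pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 12), base_size=(8, 12))
    assert pos.shape == (8 * 12, 64)
    return pos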
class PatchEmbed(nn.Module):
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, (self.height, self.width), base_size=(self.height, self.width), interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=(height, width),
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
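# Illustrative shape check (arbitrary sizes): a (B, C, H, W) latent is split into
# (H / p) * (W / p) patches, each projected to embed_dim, giving (B, N, embed_dim).
def _example_patch_embed():
    embedder = PatchEmbed(height=32, width=96, patch_size=2, in_channels=8, embed_dim=512)
    latent = torch.randn(1, 8, 32, 96)
    tokens = embedder(latent)              # (1, 16 * 48, 512)
    assert tokens.shape == (1, 768, 512)
    return tokens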
class Attention(nn.Module):
def __init__(
self,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
out_bias: bool = True,
):
super().__init__()
self.inner_dim = dim_head * heads
self.use_bias = bias
self.dropout = dropout
self.heads = heads
self.to_q = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.inner_dim, self.inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(self.inner_dim, self.inner_dim, bias=out_bias),
nn.Dropout(dropout)
])
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
):
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = hidden_states.shape[0] if encoder_hidden_states is None else encoder_hidden_states.shape[0]
query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states
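# Illustrative self-/cross-attention sketch (arbitrary sizes): the projections are
# square with width heads * dim_head, so inputs must already have that width.
def _example_attention():
    attn = Attention(heads=8, dim_head=64)              # inner_dim = 512
    x = torch.randn(2, 16, 512)                         # query tokens
    ctx = torch.randn(2, 77, 512)                       # optional context tokens
    self_out = attn(x)                                  # (2, 16, 512)
    cross_out = attn(x, encoder_hidden_states=ctx)      # (2, 16, 512)
    return self_out.shape, cross_out.shape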
class DiTBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
activation_fn: str = "geglu",
attention_bias: bool = False,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
pixel_hidden_states: Optional[torch.FloatTensor] = None,
):
batch_size = hidden_states.shape[0]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
hidden_states_len = norm_hidden_states.shape[1]
attn_output = self.attn1(
torch.cat([pixel_hidden_states, norm_hidden_states], dim=1),
)[:, -hidden_states_len:]
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
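# Illustrative DiTBlock sketch (arbitrary sizes). Note that `timestep` here is the
# already-projected adaLN-single vector of width 6 * dim, as produced by
# AdaLayerNormClassEmb inside D3D_DiT, not a raw integer timestep.
def _example_dit_block():
    dim = 512
    block = DiTBlock(dim, num_attention_heads=8, attention_head_dim=64)
    hidden = torch.randn(2, 768, dim)      # patched latent tokens
    semantic = torch.randn(2, 257, dim)    # projected semantic image tokens
    pixel = torch.randn(2, 257, dim)       # projected pixel-level image tokens
    adaln = torch.randn(2, 6 * dim)        # packed shift/scale/gate parameters
    out = block(hidden, encoder_hidden_states=semantic,
                timestep=adaln, pixel_hidden_states=pixel)
    assert out.shape == hidden.shape
    return out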
class D3D_DiT(nn.Module):
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 72,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
semantic_channels: int = None,
pixel_channels: int = None,
interpolation_scale: float = 1.0,
gradient_checkpointing: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
if isinstance(sample_size, int):
sample_size = (sample_size, sample_size)
self.height = sample_size[0]
self.width = sample_size[1]
self.patch_size = patch_size
interpolation_scale = (
            interpolation_scale if interpolation_scale is not None else max(min(sample_size) // 32, 1)
)
self.pos_embed = PatchEmbed(
height=sample_size[0],
width=sample_size[1],
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
self.transformer_blocks = nn.ModuleList(
[
DiTBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for d in range(num_layers)
]
)
self.out_channels = in_channels if out_channels is None else out_channels
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
self.adaln_single = AdaLayerNormClassEmb(inner_dim, semantic_channels)
self.semantic_projection = PixArtAlphaTextProjection(in_features=semantic_channels, hidden_size=inner_dim)
self.pixel_projection = PixArtAlphaTextProjection(in_features=pixel_channels, hidden_size=inner_dim)
self.gradient_checkpointing = gradient_checkpointing
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
pixel_hidden_states: Optional[torch.Tensor] = None,
):
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
timestep, embedded_timestep = self.adaln_single(
timestep, class_embedding=encoder_hidden_states[:, 0], hidden_dtype=hidden_states.dtype
)
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.semantic_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
pixel_hidden_states = self.pixel_projection(pixel_hidden_states)
pixel_hidden_states = pixel_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
encoder_hidden_states,
timestep,
pixel_hidden_states,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
pixel_hidden_states=pixel_hidden_states,
)
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
return output
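# Illustrative end-to-end sketch of the DiT (channel and size numbers are small
# arbitrary placeholders, not the released configuration): the model denoises a
# rolled-out triplane latent conditioned on semantic and pixel-level image tokens.
def _example_d3d_dit():
    model = D3D_DiT(
        num_attention_heads=8,
        attention_head_dim=64,
        in_channels=8,
        num_layers=2,
        sample_size=(32, 96),              # (triplane_res, 3 * triplane_res)
        patch_size=2,
        semantic_channels=1024,
        pixel_channels=1024,
    )
    latent = torch.randn(1, 8, 32, 96)
    timestep = torch.tensor([999])
    semantic = torch.randn(1, 257, 1024)   # tokens from the semantic image encoder
    pixel = torch.randn(1, 257, 1024)      # tokens from the pixel-level image encoder
    noise_pred = model(latent, encoder_hidden_states=semantic,
                       timestep=timestep, pixel_hidden_states=pixel)
    assert noise_pred.shape == latent.shape
    return noise_pred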
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py
import trimesh
import itertools
import numpy as np
from tqdm import tqdm
from einops import rearrange, repeat
from skimage import measure
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from direct3d.utils.triplane import sample_from_planes, generate_planes
from diffusers.models.autoencoders.vae import UNetMidBlock2D
from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self):
x = self.mean + self.std * torch.randn_like(self.mean)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.mean(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.mean(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=(1, 2, 3)):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
class FourierEmbedder(nn.Module):
def __init__(self,
num_freqs: int = 6,
input_dim: int = 3):
super().__init__()
freq = 2.0 ** torch.arange(num_freqs)
self.register_buffer("freq", freq, persistent=False)
self.num_freqs = num_freqs
self.out_dim = input_dim * (num_freqs * 2 + 1)
def forward(self, x: torch.Tensor):
embed = (x[..., None].contiguous() * self.freq).view(*x.shape[:-1], -1)
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
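# Illustrative check (arbitrary sizes): each coordinate is kept and augmented with
# sin/cos at num_freqs octaves, so (B, N, 3) points become (B, N, 3 * (2 * num_freqs + 1)).
def _example_fourier_embedder():
    embedder = FourierEmbedder(num_freqs=8)
    points = torch.rand(2, 1024, 3) * 2 - 1    # query points in [-1, 1]^3
    feats = embedder(points)                   # (2, 1024, 51)
    assert feats.shape[-1] == embedder.out_dim == 51
    return feats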
class OccDecoder(nn.Module):
def __init__(self,
n_features: int,
hidden_dim: int = 64,
num_layers: int = 4,
activation: nn.Module = nn.ReLU,
final_activation: str = None):
super().__init__()
self.net = nn.Sequential(
nn.Linear(3 * n_features, hidden_dim),
activation(),
*itertools.chain(*[[
nn.Linear(hidden_dim, hidden_dim),
activation(),
] for _ in range(num_layers - 2)]),
nn.Linear(hidden_dim, 1),
)
self.final_activation = final_activation
def forward(self, sampled_features):
x = rearrange(sampled_features, "N_b N_t N_s C -> N_b N_s (N_t C)")
x = self.net(x)
if self.final_activation is None:
pass
elif self.final_activation == 'tanh':
x = torch.tanh(x)
elif self.final_activation == 'sigmoid':
x = torch.sigmoid(x)
else:
raise ValueError(f"Unknown final activation: {self.final_activation}")
return x[..., 0]
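# Illustrative check (arbitrary sizes): the decoder takes one feature vector per
# plane per query point, concatenates the three planes, and returns a raw
# occupancy logit per point.
def _example_occ_decoder():
    decoder = OccDecoder(n_features=32)
    sampled = torch.randn(2, 3, 4096, 32)      # (batch, planes, points, channels)
    logits = decoder(sampled)                  # (2, 4096)
    assert logits.shape == (2, 4096)
    return logits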
class Attention(nn.Module):
def __init__(self,
dim: int,
heads: int = 8,
dim_head: int = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.to_out = nn.Linear(inner_dim, inner_dim)
def forward(self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
):
batch_size = hidden_states.shape[0]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = self.to_q(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = self.to_out(hidden_states)
return hidden_states
class TransformerBlock(nn.Module):
def __init__(self,
num_attention_heads: int,
attention_head_dim: int,
cross_attention: bool = False):
super().__init__()
inner_dim = attention_head_dim * num_attention_heads
self.norm1 = nn.LayerNorm(inner_dim)
if cross_attention:
self.norm1_c = nn.LayerNorm(inner_dim)
else:
self.norm1_c = None
self.attn = Attention(inner_dim, num_attention_heads, attention_head_dim)
self.norm2 = nn.LayerNorm(inner_dim)
self.mlp = nn.Sequential(
nn.Linear(inner_dim, 4 * inner_dim),
nn.GELU(),
nn.Linear(4 * inner_dim, inner_dim),
)
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
if self.norm1_c is not None:
x = self.attn(self.norm1(x), self.norm1_c(y)) + x
else:
x = self.attn(self.norm1(x)) + x
x = x + self.mlp(self.norm2(x))
return x
class PointEncoder(nn.Module):
def __init__(self,
num_latents: int,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
gradient_checkpointing: bool = False):
super().__init__()
self.gradient_checkpointing = gradient_checkpointing
self.num_latents = num_latents
inner_dim = attention_head_dim * num_attention_heads
self.learnable_token = nn.Parameter(torch.randn((num_latents, inner_dim)) * 0.01)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.cross_attn = TransformerBlock(num_attention_heads, attention_head_dim, cross_attention=True)
self.self_attn = nn.ModuleList([
TransformerBlock(num_attention_heads, attention_head_dim) for _ in range(num_layers)
])
self.norm_out = nn.LayerNorm(inner_dim)
def forward(self, pc):
bs = pc.shape[0]
pc = self.proj_in(pc)
learnable_token = repeat(self.learnable_token, "m c -> b m c", b=bs)
if self.training and self.gradient_checkpointing:
latents = torch.utils.checkpoint.checkpoint(self.cross_attn, learnable_token, pc)
for block in self.self_attn:
latents = torch.utils.checkpoint.checkpoint(block, latents)
else:
latents = self.cross_attn(learnable_token, pc)
for block in self.self_attn:
latents = block(latents)
latents = self.norm_out(latents)
return latents
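# Illustrative check (arbitrary sizes): the point cloud is cross-attended into a
# fixed set of learnable latent tokens, so the output length is num_latents
# regardless of how many input points are provided.
def _example_point_encoder():
    encoder = PointEncoder(num_latents=768, in_channels=54,
                           num_attention_heads=8, attention_head_dim=64,
                           num_layers=2)
    points = torch.randn(2, 20000, 54)         # Fourier-embedded xyz plus extra features
    latents = encoder(points)                  # (2, 768, 512)
    assert latents.shape == (2, 768, 512)
    return latents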
class TriplaneDecoder(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group",
mid_block_add_attention=True,
gradient_checkpointing: bool = False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=norm_type,
temb_channels=temb_channels,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "group":
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
else:
raise ValueError(f"Unsupported norm type: {norm_type}")
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, sample: torch.Tensor):
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
# middle
sample = torch.utils.checkpoint.checkpoint(
self.mid_block, sample
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(up_block, sample)
else:
# middle
sample = self.mid_block(sample)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
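# Illustrative check (arbitrary channel counts): the decoder upsamples by 2x per
# non-final up block, so three block_out_channels entries yield a 4x larger
# rolled-out triplane with out_channels features.
def _example_triplane_decoder():
    decoder = TriplaneDecoder(in_channels=512, out_channels=32,
                              block_out_channels=(64, 128, 256))
    latent = torch.randn(1, 512, 32, 96)       # rolled-out triplane latent
    planes = decoder(latent)                   # (1, 32, 128, 384)
    assert planes.shape == (1, 32, 128, 384)
    return planes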
class D3D_VAE(nn.Module):
def __init__(self,
triplane_res: int,
latent_dim: int = 0,
triplane_dim: int = 32,
num_freqs: int = 8,
num_attention_heads: int = 12,
attention_head_dim: int = 64,
num_encoder_layers: int = 8,
num_geodecoder_layers: int = 5,
final_activation: str = None,
block_out_channels=[128, 256, 512, 512],
mid_block_add_attention=True,
gradient_checkpointing: bool = False,
latents_scale: float = 1.0,
latents_shift: float = 0.0):
super().__init__()
self.gradient_checkpointing = gradient_checkpointing
self.triplane_res = triplane_res
self.num_latents = triplane_res ** 2 * 3
self.latent_shape = (latent_dim, triplane_res, 3 * triplane_res)
self.latents_scale = latents_scale
self.latents_shift = latents_shift
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs)
inner_dim = attention_head_dim * num_attention_heads
self.encoder = PointEncoder(
num_latents=self.num_latents,
in_channels=self.fourier_embedder.out_dim + 3,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_encoder_layers,
gradient_checkpointing=gradient_checkpointing,
)
self.latent_dim = latent_dim
self.pre_latent = nn.Conv2d(inner_dim, 2 * latent_dim, 1)
self.post_latent = nn.Conv2d(latent_dim, inner_dim, 1)
self.decoder = TriplaneDecoder(
in_channels=inner_dim,
out_channels=triplane_dim,
block_out_channels=block_out_channels,
mid_block_add_attention=mid_block_add_attention,
gradient_checkpointing=gradient_checkpointing,
)
self.plane_axes = generate_planes()
self.occ_decoder = OccDecoder(
n_features=triplane_dim,
num_layers=num_geodecoder_layers,
final_activation=final_activation,
)
def rollout(self, triplane):
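        # (N_b, 3*C, H_t, W_t) -> (N_b, C, H_t, 3*W_t): lay the three planes side by side along the width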
triplane = rearrange(triplane, "N_b (N_t C) H_t W_t -> N_b C H_t (N_t W_t)", N_t=3)
return triplane
def unrollout(self, triplane):
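        # (N_b, C, H_t, 3*W_t) -> (N_b, 3, C, H_t, W_t): split the width back into an explicit plane axis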
triplane = rearrange(triplane, "N_b C H_t (N_t W_t) -> N_b N_t C H_t W_t", N_t=3)
return triplane
def encode(self,
pc: torch.FloatTensor,
feats: Optional[torch.FloatTensor] = None):
x = self.fourier_embedder(pc)
if feats is not None:
x = torch.cat((x, feats), dim=-1)
x = self.encoder(x)
x = rearrange(x, "N_b (N_t H_t W_t) C -> N_b (N_t C) H_t W_t",
N_t=3, H_t=self.triplane_res, W_t=self.triplane_res)
x = self.rollout(x)
moments = self.pre_latent(x)
posterior = DiagonalGaussianDistribution(moments)
latents = posterior.sample()
return latents, posterior
def decode(self, z, unrollout=False):
z = self.post_latent(z)
dec = self.decoder(z)
if unrollout:
dec = self.unrollout(dec)
return dec
def decode_mesh(self,
latents,
bounds: Union[Tuple[float], List[float], float] = 1.0,
voxel_resolution: int = 512,
mc_threshold: float = 0.0):
triplane = self.decode(latents, unrollout=True)
mesh = self.triplane2mesh(triplane,
bounds=bounds,
voxel_resolution=voxel_resolution,
mc_threshold=mc_threshold)
return mesh
def triplane2mesh(self,
latents: torch.FloatTensor,
bounds: Union[Tuple[float], List[float], float] = 1.0,
voxel_resolution: int = 512,
mc_threshold: float = 0.0,
chunk_size: int = 50000):
batch_size = len(latents)
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
bbox_length = bbox_max - bbox_min
x = torch.linspace(bbox_min[0], bbox_max[0], steps=int(voxel_resolution) + 1)
y = torch.linspace(bbox_min[1], bbox_max[1], steps=int(voxel_resolution) + 1)
z = torch.linspace(bbox_min[2], bbox_max[2], steps=int(voxel_resolution) + 1)
xs, ys, zs = torch.meshgrid(x, y, z, indexing='ij')
xyz = torch.stack((xs, ys, zs), dim=-1)
xyz = xyz.reshape(-1, 3)
grid_size = [int(voxel_resolution) + 1, int(voxel_resolution) + 1, int(voxel_resolution) + 1]
logits_total = []
for start in tqdm(range(0, xyz.shape[0], chunk_size), desc="Triplane Sampling:"):
positions = xyz[start:start + chunk_size].to(latents.device)
positions = repeat(positions, "p d -> b p d", b=batch_size)
triplane_features = sample_from_planes(self.plane_axes.to(latents.device),
latents, positions,
box_warp=2.0)
logits = self.occ_decoder(triplane_features)
logits_total.append(logits)
logits_total = torch.cat(logits_total, dim=1).view(
(batch_size, grid_size[0], grid_size[1], grid_size[2])).cpu().numpy()
meshes = []
for i in range(batch_size):
vertices, faces, _, _ = measure.marching_cubes(
logits_total[i],
mc_threshold,
method="lewiner"
)
vertices = vertices / grid_size * bbox_length + bbox_min
faces = faces[:, ::-1]
meshes.append(trimesh.Trimesh(vertices, faces))
return meshes
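# Illustrative encode/decode sketch (the hyper-parameters are small placeholders,
# not the released configuration): a surface point cloud with three extra per-point
# features is encoded into a rolled-out triplane latent and decoded back into
# triplane features; decode_mesh would additionally run marching cubes.
def _example_d3d_vae():
    vae = D3D_VAE(triplane_res=32, latent_dim=8, triplane_dim=32,
                  num_attention_heads=8, attention_head_dim=64,
                  num_encoder_layers=2, num_geodecoder_layers=4,
                  block_out_channels=[64, 128])
    points = torch.rand(1, 2048, 3) * 2 - 1    # surface samples in [-1, 1]^3
    feats = torch.randn(1, 2048, 3)            # e.g. per-point normals
    latents, posterior = vae.encode(points, feats)
    assert latents.shape == (1, *vae.latent_shape)   # (1, 8, 32, 96)
    planes = vae.decode(latents, unrollout=True)     # (1, 3, 32, 64, 64)
    return planes.shape, posterior.kl()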
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
import os
from tqdm import tqdm
from PIL import Image
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from typing import Union, List, Optional
import torch
from direct3d.utils import instantiate_from_config, preprocess
from diffusers.utils.torch_utils import randn_tensor
class Direct3dPipeline(object):
def __init__(self,
vae,
dit,
semantic_encoder,
pixel_encoder,
scheduler):
self.vae = vae
self.dit = dit
self.semantic_encoder = semantic_encoder
self.pixel_encoder = pixel_encoder
self.scheduler = scheduler
def to(self, device):
self.device = torch.device(device)
self.vae.to(device)
self.dit.to(device)
self.semantic_encoder.to(device)
self.pixel_encoder.to(device)
@classmethod
def from_pretrained(cls,
pipeline_path):
if os.path.isdir(pipeline_path):
config_path = os.path.join(pipeline_path, 'config.yaml')
model_path = os.path.join(pipeline_path, 'model.ckpt')
else:
config_path = hf_hub_download(repo_id=pipeline_path, filename="config.yaml", repo_type="model")
model_path = hf_hub_download(repo_id=pipeline_path, filename="model.ckpt", repo_type="model")
cfg = OmegaConf.load(config_path)
state_dict = torch.load(model_path, map_location='cpu')
vae = instantiate_from_config(cfg.vae)
vae.load_state_dict(state_dict["vae"], strict=True)
dit = instantiate_from_config(cfg.dit)
dit.load_state_dict(state_dict["dit"], strict=True)
semantic_encoder = instantiate_from_config(cfg.semantic_encoder)
pixel_encoder = instantiate_from_config(cfg.pixel_encoder)
scheduler = instantiate_from_config(cfg.scheduler)
return cls(
vae=vae,
dit=dit,
semantic_encoder=semantic_encoder,
pixel_encoder=pixel_encoder,
scheduler=scheduler)
def prepare_image(self, image: Union[str, List[str], Image.Image, List[Image.Image]], rmbg: bool = True):
if not isinstance(image, list):
image = [image]
if isinstance(image[0], str):
image = [Image.open(img) for img in image]
image = [preprocess(img, rmbg=rmbg) for img in image]
image = torch.stack([img for img in image]).to(self.device)
return image
def encode_image(self, image: torch.Tensor, do_classifier_free_guidance: bool = True):
semantic_cond = self.semantic_encoder(image)
pixel_cond = self.pixel_encoder(image)
if do_classifier_free_guidance:
semantic_uncond = torch.zeros_like(semantic_cond)
pixel_uncond = torch.zeros_like(pixel_cond)
semantic_cond = torch.cat([semantic_uncond, semantic_cond], dim=0)
pixel_cond = torch.cat([pixel_uncond, pixel_cond], dim=0)
return semantic_cond, pixel_cond
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator):
shape = (
batch_size,
num_channels_latents,
height,
width,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(
self,
image: Union[str, List[str], Image.Image, List[Image.Image]] = None,
num_inference_steps: int = 50,
guidance_scale: float = 4.0,
generator: Optional[torch.Generator] = None,
mc_threshold: float = -2.0,
remove_background: bool = True,):
batch_size = len(image) if isinstance(image, list) else 1
do_classifier_free_guidance = guidance_scale > 0
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
image = self.prepare_image(image, remove_background)
semantic_cond, pixel_cond = self.encode_image(image, do_classifier_free_guidance)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels_latents=self.vae.latent_shape[0],
height=self.vae.latent_shape[1],
width=self.vae.latent_shape[2],
dtype=image.dtype,
device=self.device,
generator=generator,
)
extra_step_kwargs = {
"generator": generator
}
for i, t in enumerate(tqdm(timesteps, desc="Diffusion Sampling:")):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
t = t.expand(latent_model_input.shape[0])
noise_pred = self.dit(
hidden_states=latent_model_input,
timestep=t,
encoder_hidden_states=semantic_cond,
pixel_hidden_states=pixel_cond,
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = 1. / self.vae.latents_scale * latents + self.vae.latents_shift
meshes = self.vae.decode_mesh(latents, mc_threshold=mc_threshold)
outputs = {"meshes": meshes, "latents": latents}
return outputs
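# Illustrative inference sketch (the checkpoint path, image path, and output file
# are placeholders; point them at a real local directory or Hugging Face repo id
# containing config.yaml and model.ckpt, and at a real input image):
def _example_pipeline():
    pipeline = Direct3dPipeline.from_pretrained("path/to/direct3d-checkpoint")
    pipeline.to("cuda")
    outputs = pipeline(
        image="assets/example.png",        # path, PIL image, or a list of either
        num_inference_steps=50,
        guidance_scale=4.0,
        remove_background=True,
    )
    mesh = outputs["meshes"][0]            # a trimesh.Trimesh
    mesh.export("output.obj")
    return outputs["latents"].shape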
from .util import instantiate_from_config, get_obj_from_str
from .image import preprocess