Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from typing import *
import torch
import torch.nn as nn
from ..basic import VarLenTensor, SparseTensor
from ..linear import SparseLinear
from ..nonlinearity import SparseGELU
from ..attention import SparseMultiHeadAttention
from ...norm import LayerNorm32
class SparseFeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio)),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels),
)
def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x)
class SparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor) -> SparseTensor:
h = x.replace(self.norm1(x.feats))
h = self.attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
h = x.replace(self.norm1(x.feats))
h = self.self_attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
else:
return self._forward(x, context)
from typing import *
import torch
import torch.nn as nn
from ..basic import VarLenTensor, SparseTensor
from ..attention import SparseMultiHeadAttention
from ...norm import LayerNorm32
from .blocks import SparseFeedForwardNet
class ModulatedSparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.attn(h)
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
x = x + h
return x
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
else:
return self._forward(x, mod)
class ModulatedSparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.self_attn(h)
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
x = x + h
return x
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
else:
return self._forward(x, mod, context)
import torch
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
"""
3D pixel shuffle.
"""
B, C, H, W, D = x.shape
C_ = C // scale_factor**3
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
return x
def patchify(x: torch.Tensor, patch_size: int):
"""
Patchify a tensor.
Args:
x (torch.Tensor): (N, C, *spatial) tensor
patch_size (int): Patch size
"""
DIM = x.dim() - 2
for d in range(2, DIM + 2):
assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
return x
def unpatchify(x: torch.Tensor, patch_size: int):
"""
Unpatchify a tensor.
Args:
x (torch.Tensor): (N, C, *spatial) tensor
patch_size (int): Patch size
"""
DIM = x.dim() - 2
assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
return x
from .blocks import *
from .modulated import *
\ No newline at end of file
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
class AbsolutePositionEmbedder(nn.Module):
"""
Embeds spatial positions into vector representations.
"""
def __init__(self, channels: int, in_channels: int = 3):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.freq_dim = channels // in_channels // 2
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = 1.0 / (10000 ** self.freqs)
def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
"""
Create sinusoidal position embeddings.
Args:
x: a 1-D Tensor of N indices
Returns:
an (N, D) Tensor of positional embeddings.
"""
self.freqs = self.freqs.to(x.device)
out = torch.outer(x, self.freqs)
out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): (N, D) tensor of spatial positions
"""
N, D = x.shape
assert D == self.in_channels, "Input dimension must match number of input channels"
embed = self._sin_cos_embedding(x.reshape(-1))
embed = embed.reshape(N, -1)
if embed.shape[1] < self.channels:
embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
return embed
class FeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(channels * mlp_ratio), channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class TransformerBlock(nn.Module):
"""
Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[int] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = True,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = MultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.attn(h, phases=phases)
x = x + h
h = self.norm2(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False)
else:
return self._forward(x, phases)
class TransformerCrossBlock(nn.Module):
"""
Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.self_attn(h, phases=phases)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False)
else:
return self._forward(x, context, phases)
\ No newline at end of file
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
from .blocks import FeedForwardNet
class ModulatedTransformerBlock(nn.Module):
"""
Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.attn = MultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.attn(h, phases=phases)
h = h * gate_msa.unsqueeze(1)
x = x + h
h = self.norm2(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
h = h * gate_mlp.unsqueeze(1)
x = x + h
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False)
else:
return self._forward(x, mod, phases)
class ModulatedTransformerCrossBlock(nn.Module):
"""
Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.self_attn(h, phases=phases)
h = h * gate_msa.unsqueeze(1)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = self.mlp(h)
h = h * gate_mlp.unsqueeze(1)
x = x + h
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False)
else:
return self._forward(x, mod, context, phases)
\ No newline at end of file
import torch
import torch.nn as nn
from ..modules import sparse as sp
MIX_PRECISION_MODULES = (
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
sp.SparseConv3d,
sp.SparseInverseConv3d,
sp.SparseLinear,
)
def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.half()
def convert_module_to_bf16(l):
"""
Convert primitive modules to bfloat16.
"""
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.bfloat16()
def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.float()
def convert_module_to(l, dtype):
"""
Convert primitive modules to the given dtype.
"""
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.to(dtype)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def manual_cast(tensor, dtype):
"""
Cast if autocast is not enabled.
"""
if not torch.is_autocast_enabled():
return tensor.type(dtype)
return tensor
def str_to_dtype(dtype_str: str):
return {
'f16': torch.float16,
'fp16': torch.float16,
'float16': torch.float16,
'bf16': torch.bfloat16,
'bfloat16': torch.bfloat16,
'f32': torch.float32,
'fp32': torch.float32,
'float32': torch.float32,
}[dtype_str]
import importlib
__attributes = {
"Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
"Trellis2TexturingPipeline": "trellis2_texturing",
}
__submodules = ['samplers', 'rembg']
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
def from_pretrained(path: str):
"""
Load a pipeline from a model folder or a Hugging Face model hub.
Args:
path: The path to the model. Can be either local path or a Hugging Face model name.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
if is_local:
config_file = f"{path}/pipeline.json"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
with open(config_file, 'r') as f:
config = json.load(f)
return globals()[config['name']].from_pretrained(path)
# For PyLance
if __name__ == '__main__':
from . import samplers, rembg
from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
from .trellis2_texturing import Trellis2TexturingPipeline
from typing import *
import torch
import torch.nn as nn
from .. import models
class Pipeline:
"""
A base class for pipelines.
"""
def __init__(
self,
models: dict[str, nn.Module] = None,
):
if models is None:
return
self.models = models
for model in self.models.values():
model.eval()
@classmethod
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
"""
Load a pretrained model.
"""
import os
import json
is_local = os.path.exists(f"{path}/{config_file}")
if is_local:
config_file = f"{path}/{config_file}"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, config_file)
with open(config_file, 'r') as f:
args = json.load(f)['args']
_models = {}
for k, v in args['models'].items():
if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
continue
try:
_models[k] = models.from_pretrained(f"{path}/{v}")
except Exception as e:
_models[k] = models.from_pretrained(v)
new_pipeline = cls(_models)
new_pipeline._pretrained_args = args
return new_pipeline
@property
def device(self) -> torch.device:
if hasattr(self, '_device'):
return self._device
for model in self.models.values():
if hasattr(model, 'device'):
return model.device
for model in self.models.values():
if hasattr(model, 'parameters'):
return next(model.parameters()).device
raise RuntimeError("No device found.")
def to(self, device: torch.device) -> None:
for model in self.models.values():
model.to(device)
def cuda(self) -> None:
self.to(torch.device("cuda"))
def cpu(self) -> None:
self.to(torch.device("cpu"))
\ No newline at end of file
from typing import *
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
class BiRefNet:
def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"):
# transformers 5.x calls all_tied_weights_keys.keys() during model loading,
# but BiRefNet (trust_remote_code) was written for older transformers and doesn't
# define this attribute. Patch the base class before loading.
from transformers import PreTrainedModel
if not hasattr(PreTrainedModel, '_trellis2_patched'):
_method_name = '_move_missing_keys_from_meta_to_device' if hasattr(PreTrainedModel, '_move_missing_keys_from_meta_to_device') else '_move_missing_keys_from_meta_to_cpu'
_orig = getattr(PreTrainedModel, _method_name)
def _patched(self_model, *args, **kwargs):
if not hasattr(self_model, 'all_tied_weights_keys'):
self_model.all_tied_weights_keys = {}
return _orig(self_model, *args, **kwargs)
setattr(PreTrainedModel, _method_name, _patched)
PreTrainedModel._trellis2_patched = True
self.model = AutoModelForImageSegmentation.from_pretrained(
model_name, trust_remote_code=True
)
self.model.eval()
self.transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
def to(self, device: str):
self.model.to(device)
def cuda(self):
self.model.cuda()
def cpu(self):
self.model.cpu()
def __call__(self, image: Image.Image) -> Image.Image:
image_size = image.size
input_images = self.transform_image(image).unsqueeze(0).to("cuda")
# Prediction
with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
from .base import Sampler
from .flow_euler import (
FlowEulerSampler,
FlowEulerCfgSampler,
FlowEulerGuidanceIntervalSampler,
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment