Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..modules.norm import GroupNorm32, ChannelLayerNorm32
from ..modules.spatial import pixel_shuffle_3d
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
"""
Return a normalization layer.
"""
if norm_type == "group":
return GroupNorm32(32, *args, **kwargs)
elif norm_type == "layer":
return ChannelLayerNorm32(*args, **kwargs)
else:
raise ValueError(f"Invalid norm type {norm_type}")
class ResBlock3d(nn.Module):
def __init__(
self,
channels: int,
out_channels: Optional[int] = None,
norm_type: Literal["group", "layer"] = "layer",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.norm1 = norm_layer(norm_type, channels)
self.norm2 = norm_layer(norm_type, self.out_channels)
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm1(x)
h = F.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = F.silu(h)
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
class DownsampleBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
mode: Literal["conv", "avgpool"] = "conv",
):
assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mode == "conv":
self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
elif mode == "avgpool":
assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "conv"):
return self.conv(x)
else:
return F.avg_pool3d(x, 2)
class UpsampleBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
mode: Literal["conv", "nearest"] = "conv",
):
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if mode == "conv":
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
elif mode == "nearest":
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(self, "conv"):
x = self.conv(x)
return pixel_shuffle_3d(x, 2)
else:
return F.interpolate(x, scale_factor=2, mode="nearest")
class SparseStructureEncoder(nn.Module):
"""
Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
Args:
in_channels (int): Channels of the input.
latent_channels (int): Channels of the latent representation.
num_res_blocks (int): Number of residual blocks at each resolution.
channels (List[int]): Channels of the encoder blocks.
num_res_blocks_middle (int): Number of residual blocks in the middle.
norm_type (Literal["group", "layer"]): Type of normalization layer.
use_fp16 (bool): Whether to use FP16.
"""
def __init__(
self,
in_channels: int,
latent_channels: int,
num_res_blocks: int,
channels: List[int],
num_res_blocks_middle: int = 2,
norm_type: Literal["group", "layer"] = "layer",
use_fp16: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.latent_channels = latent_channels
self.num_res_blocks = num_res_blocks
self.channels = channels
self.num_res_blocks_middle = num_res_blocks_middle
self.norm_type = norm_type
self.use_fp16 = use_fp16
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
self.blocks = nn.ModuleList([])
for i, ch in enumerate(channels):
self.blocks.extend([
ResBlock3d(ch, ch)
for _ in range(num_res_blocks)
])
if i < len(channels) - 1:
self.blocks.append(
DownsampleBlock3d(ch, channels[i+1])
)
self.middle_block = nn.Sequential(*[
ResBlock3d(channels[-1], channels[-1])
for _ in range(num_res_blocks_middle)
])
self.out_layer = nn.Sequential(
norm_layer(norm_type, channels[-1]),
nn.SiLU(),
nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
)
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.use_fp16 = True
self.dtype = torch.float16
self.blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.use_fp16 = False
self.dtype = torch.float32
self.blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
h = self.input_layer(x)
h = h.type(self.dtype)
for block in self.blocks:
h = block(h)
h = self.middle_block(h)
h = h.type(x.dtype)
h = self.out_layer(h)
mean, logvar = h.chunk(2, dim=1)
if sample_posterior:
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
else:
z = mean
if return_raw:
return z, mean, logvar
return z
class SparseStructureDecoder(nn.Module):
"""
Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
Args:
out_channels (int): Channels of the output.
latent_channels (int): Channels of the latent representation.
num_res_blocks (int): Number of residual blocks at each resolution.
channels (List[int]): Channels of the decoder blocks.
num_res_blocks_middle (int): Number of residual blocks in the middle.
norm_type (Literal["group", "layer"]): Type of normalization layer.
use_fp16 (bool): Whether to use FP16.
"""
def __init__(
self,
out_channels: int,
latent_channels: int,
num_res_blocks: int,
channels: List[int],
num_res_blocks_middle: int = 2,
norm_type: Literal["group", "layer"] = "layer",
use_fp16: bool = False,
):
super().__init__()
self.out_channels = out_channels
self.latent_channels = latent_channels
self.num_res_blocks = num_res_blocks
self.channels = channels
self.num_res_blocks_middle = num_res_blocks_middle
self.norm_type = norm_type
self.use_fp16 = use_fp16
self.dtype = torch.float16 if use_fp16 else torch.float32
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
self.middle_block = nn.Sequential(*[
ResBlock3d(channels[0], channels[0])
for _ in range(num_res_blocks_middle)
])
self.blocks = nn.ModuleList([])
for i, ch in enumerate(channels):
self.blocks.extend([
ResBlock3d(ch, ch)
for _ in range(num_res_blocks)
])
if i < len(channels) - 1:
self.blocks.append(
UpsampleBlock3d(ch, channels[i+1])
)
self.out_layer = nn.Sequential(
norm_layer(norm_type, channels[-1]),
nn.SiLU(),
nn.Conv3d(channels[-1], out_channels, 3, padding=1)
)
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.use_fp16 = True
self.dtype = torch.float16
self.blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.use_fp16 = False
self.dtype = torch.float32
self.blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.input_layer(x)
h = h.type(self.dtype)
h = self.middle_block(h)
for block in self.blocks:
h = block(h)
h = h.type(x.dtype)
h = self.out_layer(h)
return h
from typing import *
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..modules.utils import convert_module_to, manual_cast, str_to_dtype
from ..modules.transformer import AbsolutePositionEmbedder
from ..modules import sparse as sp
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
from .sparse_structure_flow import TimestepEmbedder
from .sparse_elastic_mixin import SparseTransformerElasticMixin
class SLatFlowModel(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "ape",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
dtype: str = 'float32',
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = str_to_dtype(dtype)
self.t_embedder = TimestepEmbedder(model_channels)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
)
if pe_mode == "ape":
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
self.input_layer = sp.SparseLinear(in_channels, model_channels)
self.blocks = nn.ModuleList([
ModulatedSparseTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
)
for _ in range(num_blocks)
])
self.out_layer = sp.SparseLinear(model_channels, out_channels)
self.initialize_weights()
self.convert_to(self.dtype)
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to(self, dtype: torch.dtype) -> None:
"""
Convert the torso of the model to the specified dtype.
"""
self.dtype = dtype
self.blocks.apply(partial(convert_module_to, dtype=dtype))
def initialize_weights(self) -> None:
if self.initialization == 'vanilla':
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
if self.share_mod:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
else:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
elif self.initialization == 'scaled':
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels)))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Scaled init for to_out and ffn2
def _scaled_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels))
if module.bias is not None:
nn.init.constant_(module.bias, 0)
for block in self.blocks:
block.self_attn.to_out.apply(_scaled_init)
block.cross_attn.to_out.apply(_scaled_init)
block.mlp.mlp[2].apply(_scaled_init)
# Initialize input layer to make the initial representation have variance 1
nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels))
nn.init.zeros_(self.input_layer.bias)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
if self.share_mod:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
else:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def forward(
self,
x: sp.SparseTensor,
t: torch.Tensor,
cond: Union[torch.Tensor, List[torch.Tensor]],
concat_cond: Optional[sp.SparseTensor] = None,
**kwargs
) -> sp.SparseTensor:
if concat_cond is not None:
x = sp.sparse_cat([x, concat_cond], dim=-1)
if isinstance(cond, list):
cond = sp.VarLenTensor.from_tensor_list(cond)
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t_emb = self.t_embedder(t)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
cond = manual_cast(cond, self.dtype)
if self.pe_mode == "ape":
pe = self.pos_embedder(h.coords[:, 1:])
h = h + manual_cast(pe, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
h = manual_cast(h, x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return h
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
"""
SLat Flow Model with elastic memory management.
Used for training with low VRAM.
"""
pass
from .full_attn import *
from .modules import *
from .rope import *
from typing import *
BACKEND = 'flash_attn'
DEBUG = False
def __from_env():
import os
global BACKEND
global DEBUG
env_attn_backend = os.environ.get('ATTN_BACKEND')
env_attn_debug = os.environ.get('ATTN_DEBUG')
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']:
BACKEND = env_attn_backend
if env_attn_debug is not None:
DEBUG = env_attn_debug == '1'
print(f"[ATTENTION] Using backend: {BACKEND}")
__from_env()
def set_backend(backend: Literal['xformers', 'flash_attn']):
global BACKEND
BACKEND = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
from typing import *
import torch
import math
from . import config
__all__ = [
'scaled_dot_product_attention',
]
def _naive_sdpa(q, k, v):
"""
Naive implementation of scaled dot product attention.
"""
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight, dim=-1)
out = attn_weight @ v
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
return out
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
"""
...
@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
"""
...
@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
Note:
k and v are assumed to have the same coordinate map.
"""
...
def scaled_dot_product_attention(*args, **kwargs):
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
device = qkv.device
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
device = q.device
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
device = q.device
if config.BACKEND == 'xformers':
if 'xops' not in globals():
import xformers.ops as xops
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = xops.memory_efficient_attention(q, k, v)
elif config.BACKEND == 'flash_attn':
if 'flash_attn' not in globals():
import flash_attn
if num_all_args == 1:
out = flash_attn.flash_attn_qkvpacked_func(qkv)
elif num_all_args == 2:
out = flash_attn.flash_attn_kvpacked_func(q, kv)
elif num_all_args == 3:
out = flash_attn.flash_attn_func(q, k, v)
elif config.BACKEND == 'flash_attn_3':
if 'flash_attn_3' not in globals():
import flash_attn_interface as flash_attn_3
if num_all_args == 1:
out = flash_attn_3.flash_attn_qkvpacked_func(qkv)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = flash_attn_3.flash_attn_func(q, k, v)
elif num_all_args == 3:
out = flash_attn_3.flash_attn_func(q, k, v)
elif config.BACKEND == 'sdpa':
if 'sdpa' not in globals():
from torch.nn.functional import scaled_dot_product_attention as sdpa
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
out = sdpa(q, k, v) # [N, H, L, C]
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
elif config.BACKEND == 'naive':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = _naive_sdpa(q, k, v)
else:
raise ValueError(f"Unknown attention module: {config.BACKEND}")
return out
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .full_attn import scaled_dot_product_attention
from .rope import RotaryPositionEmbedder
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
class MultiHeadAttention(nn.Module):
def __init__(
self,
channels: int,
num_heads: int,
ctx_channels: Optional[int]=None,
type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
):
super().__init__()
assert channels % num_heads == 0
assert type in ["self", "cross"], f"Invalid attention type: {type}"
assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
if attn_mode == "windowed":
raise NotImplementedError("Windowed attention is not yet implemented")
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
else:
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
self.to_out = nn.Linear(channels, channels)
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
B, L, C = x.shape
if self._type == "self":
qkv = self.to_qkv(x)
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
if self.attn_mode == "full":
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
assert phases is not None, "Phases must be provided for RoPE"
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(qkv)
elif self.attn_mode == "windowed":
raise NotImplementedError("Windowed attention is not yet implemented")
else:
Lkv = context.shape[1]
q = self.to_q(x)
kv = self.to_kv(context)
q = q.reshape(B, L, self.num_heads, -1)
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=2)
k = self.k_rms_norm(k)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(q, kv)
h = h.reshape(B, L, -1)
h = self.to_out(h)
return h
from typing import *
import torch
import torch.nn as nn
class RotaryPositionEmbedder(nn.Module):
def __init__(
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0)
):
super().__init__()
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
self.freq_dim = head_dim // 2 // dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
@staticmethod
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
def forward(self, indices: torch.Tensor) -> torch.Tensor:
"""
Args:
indices (torch.Tensor): [..., N, C] tensor of spatial positions
"""
assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device),
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
)], dim=-1)
return phases
\ No newline at end of file
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import DINOv3ViTModel
import numpy as np
from PIL import Image
class DinoV2FeatureExtractor:
"""
Feature extractor for DINOv2 models.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
self.model.eval()
self.transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def to(self, device):
self.model.to(device)
def cuda(self):
self.model.cuda()
def cpu(self):
self.model.cpu()
@torch.no_grad()
def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
image = self.transform(image).cuda()
features = self.model(image, is_training=True)['x_prenorm']
patchtokens = F.layer_norm(features, features.shape[-1:])
return patchtokens
class DinoV3FeatureExtractor:
"""
Feature extractor for DINOv3 models.
"""
def __init__(self, model_name: str, image_size=512):
self.model_name = model_name
self.model = DINOv3ViTModel.from_pretrained(model_name)
self.model.eval()
self.image_size = image_size
self.transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def to(self, device):
self.model.to(device)
def cuda(self):
self.model.cuda()
def cpu(self):
self.model.cpu()
def extract_features(self, image: torch.Tensor) -> torch.Tensor:
# transformers 5.x: DINOv3ViTModel is a backbone, use its forward() directly
if hasattr(self.model, 'model'):
output = self.model(image)
hidden_states = output.last_hidden_state
return F.layer_norm(hidden_states, hidden_states.shape[-1:])
# older transformers: manual layer iteration
image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
hidden_states = self.model.embeddings(image, bool_masked_pos=None)
position_embeddings = self.model.rope_embeddings(image)
for i, layer_module in enumerate(self.model.layer):
hidden_states = layer_module(
hidden_states,
position_embeddings=position_embeddings,
)
return F.layer_norm(hidden_states, hidden_states.shape[-1:])
@torch.no_grad()
def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
image = self.transform(image).cuda()
features = self.extract_features(image)
return features
import torch
import torch.nn as nn
from .utils import manual_cast
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
class GroupNorm32(nn.GroupNorm):
"""
A GroupNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
class ChannelLayerNorm32(LayerNorm32):
def forward(self, x: torch.Tensor) -> torch.Tensor:
DIM = x.dim()
x = x.permute(0, *range(2, DIM), 1).contiguous()
x = super().forward(x)
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
return x
\ No newline at end of file
from . import config
import importlib
__attributes = {
'VarLenTensor': 'basic',
'varlen_cat': 'basic',
'varlen_unbind': 'basic',
'SparseTensor': 'basic',
'sparse_cat': 'basic',
'sparse_unbind': 'basic',
'SparseGroupNorm': 'norm',
'SparseLayerNorm': 'norm',
'SparseGroupNorm32': 'norm',
'SparseLayerNorm32': 'norm',
'SparseReLU': 'nonlinearity',
'SparseSiLU': 'nonlinearity',
'SparseGELU': 'nonlinearity',
'SparseActivation': 'nonlinearity',
'SparseLinear': 'linear',
'sparse_scaled_dot_product_attention': 'attention',
'SerializeMode': 'attention',
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
'sparse_windowed_scaled_dot_product_cross_attention': 'attention',
'SparseRotaryPositionEmbedder': 'attention',
'SparseMultiHeadAttention': 'attention',
'SparseConv3d': 'conv',
'SparseInverseConv3d': 'conv',
'SparseDownsample': 'spatial',
'SparseUpsample': 'spatial',
'SparseSubdivide': 'spatial',
'SparseSpatial2Channel': 'spatial',
'SparseChannel2Spatial': 'spatial',
'sparse_nearest_interpolate': 'spatial',
'sparse_trilinear_interpolate': 'spatial',
'encode_seq': 'serialize',
'decode_seq': 'serialize',
}
__submodules = ['transformer', 'conv']
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import *
from .norm import *
from .nonlinearity import *
from .linear import *
from .attention import *
from .conv import *
from .spatial import *
from .serialize import *
import transformer
import conv
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment