Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
from typing import *
import torch
from ..voxel import Voxel
import cumesh
from flex_gemm.ops.grid_sample import grid_sample_3d
from ...utils.pipeline_logger import get_logger, log_mesh, elapsed
class Mesh:
def __init__(self,
vertices,
faces,
vertex_attrs=None
):
self.vertices = vertices.float()
self.faces = faces.int()
self.vertex_attrs = vertex_attrs
@property
def device(self):
return self.vertices.device
def to(self, device, non_blocking=False):
return Mesh(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
)
def cuda(self, non_blocking=False):
return self.to('cuda', non_blocking=non_blocking)
def cpu(self):
return self.to('cpu')
def fill_holes(self, max_hole_perimeter=3e-2):
import os, numpy as np
L = get_logger()
log_mesh(self.vertices, self.faces, "fill_holes:before")
vertices = self.vertices.cuda()
faces = self.faces.cuda()
# ------------------------------------------------------------------ #
# Debug helpers: per-step .obj dump + stats print
# ------------------------------------------------------------------ #
_dbg_dir = os.environ.get("CUMESH_DEBUG_DIR", "cumesh_debug")
_dbg_step = [0]
def _snap(label, v_tensor, f_tensor):
return
"""Dump vertex/face data to an OBJ and print min/max/nan stats."""
v = v_tensor.detach().cpu().float().numpy() # [N, 3]
f = f_tensor.detach().cpu().int().numpy() # [M, 3]
step = _dbg_step[0]
_dbg_step[0] += 1
vmin = v.min(axis=0) if len(v) else [float('nan')]*3
vmax = v.max(axis=0) if len(v) else [float('nan')]*3
all_zero_v = bool((v == 0).all()) if len(v) else True
all_zero_f = bool((f == 0).all()) if len(f) else True
nan_v = bool(np.isnan(v).any())
print(f"[CUMESH_DBG] step={step:02d} {label}")
print(f" verts : {v.shape[0]} min={vmin} max={vmax} all_zero={all_zero_v} nan={nan_v}")
print(f" faces : {f.shape[0]} all_zero={all_zero_f}")
os.makedirs(_dbg_dir, exist_ok=True)
obj_path = os.path.join(_dbg_dir, f"step{step:02d}_{label.replace(':', '_').replace('/', '_')}.obj")
with open(obj_path, "w") as fp:
fp.write(f"# step={step} {label}\n")
fp.write(f"# {v.shape[0]} vertices, {f.shape[0]} faces\n\n")
for row in v:
fp.write(f"v {row[0]:.6f} {row[1]:.6f} {row[2]:.6f}\n")
fp.write("\n")
for row in f:
fp.write(f"f {row[0]+1} {row[1]+1} {row[2]+1}\n")
print(f" -> {obj_path}")
def _snap_mesh(label):
return
"""Read current CuMesh state and dump it."""
v, f = mesh.read()
_snap(label, v, f)
# ------------------------------------------------------------------ #
mesh = cumesh.CuMesh()
mesh.init(vertices, faces)
_snap("00_after_init", vertices, faces)
mesh.get_edges()
_snap_mesh("01_after_get_edges")
mesh.get_boundary_info()
L.info(f" {elapsed()} fill_holes: num_boundaries={mesh.num_boundaries}")
_snap_mesh("02_after_get_boundary_info")
if mesh.num_boundaries == 0:
L.info(f" {elapsed()} fill_holes: no boundaries, skipping")
return
mesh.get_vertex_edge_adjacency()
_snap_mesh("03_after_get_vertex_edge_adjacency")
mesh.get_vertex_boundary_adjacency()
_snap_mesh("04_after_get_vertex_boundary_adjacency")
mesh.get_manifold_boundary_adjacency()
_snap_mesh("05_after_get_manifold_boundary_adjacency")
mesh.read_manifold_boundary_adjacency()
_snap_mesh("06_after_read_manifold_boundary_adjacency")
mesh.get_boundary_connected_components()
_snap_mesh("07_after_get_boundary_connected_components")
mesh.get_boundary_loops()
L.info(f" {elapsed()} fill_holes: num_boundary_loops={mesh.num_boundary_loops}")
_snap_mesh("08_after_get_boundary_loops")
if mesh.num_boundary_loops == 0:
return
mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
_snap_mesh("09_after_fill_holes")
new_vertices, new_faces = mesh.read()
_snap("10_final_read", new_vertices, new_faces)
log_mesh(new_vertices, new_faces, "fill_holes:after")
self.vertices = new_vertices.to(self.device)
self.faces = new_faces.to(self.device)
def remove_faces(self, face_mask: torch.Tensor):
vertices = self.vertices.cuda()
faces = self.faces.cuda()
mesh = cumesh.CuMesh()
mesh.init(vertices, faces)
mesh.remove_faces(face_mask)
new_vertices, new_faces = mesh.read()
self.vertices = new_vertices.to(self.device)
self.faces = new_faces.to(self.device)
def simplify(self, target=1000000, verbose: bool=False, options: dict={}):
L = get_logger()
log_mesh(self.vertices, self.faces, f"simplify:before(target={target})")
vertices = self.vertices.cuda()
faces = self.faces.cuda()
mesh = cumesh.CuMesh()
mesh.init(vertices, faces)
mesh.simplify(target, verbose=verbose, options=options)
new_vertices, new_faces = mesh.read()
log_mesh(new_vertices, new_faces, "simplify:after")
self.vertices = new_vertices.to(self.device)
self.faces = new_faces.to(self.device)
class TextureFilterMode:
CLOSEST = 0
LINEAR = 1
class TextureWrapMode:
CLAMP_TO_EDGE = 0
REPEAT = 1
MIRRORED_REPEAT = 2
class AlphaMode:
OPAQUE = 0
MASK = 1
BLEND = 2
class Texture:
def __init__(
self,
image: torch.Tensor,
filter_mode: TextureFilterMode = TextureFilterMode.LINEAR,
wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT
):
self.image = image
self.filter_mode = filter_mode
self.wrap_mode = wrap_mode
def to(self, device, non_blocking=False):
return Texture(
self.image.to(device, non_blocking=non_blocking),
self.filter_mode,
self.wrap_mode,
)
class PbrMaterial:
def __init__(
self,
base_color_texture: Optional[Texture] = None,
base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0],
metallic_texture: Optional[Texture] = None,
metallic_factor: float = 1.0,
roughness_texture: Optional[Texture] = None,
roughness_factor: float = 1.0,
alpha_texture: Optional[Texture] = None,
alpha_factor: float = 1.0,
alpha_mode: AlphaMode = AlphaMode.OPAQUE,
alpha_cutoff: float = 0.5,
):
self.base_color_texture = base_color_texture
self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3]
self.metallic_texture = metallic_texture
self.metallic_factor = metallic_factor
self.roughness_texture = roughness_texture
self.roughness_factor = roughness_factor
self.alpha_texture = alpha_texture
self.alpha_factor = alpha_factor
self.alpha_mode = alpha_mode
self.alpha_cutoff = alpha_cutoff
def to(self, device, non_blocking=False):
return PbrMaterial(
base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None,
base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking),
metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None,
metallic_factor=self.metallic_factor,
roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None,
roughness_factor=self.roughness_factor,
alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None,
alpha_factor=self.alpha_factor,
alpha_mode=self.alpha_mode,
alpha_cutoff=self.alpha_cutoff,
)
class MeshWithPbrMaterial(Mesh):
def __init__(self,
vertices,
faces,
material_ids,
uv_coords,
materials: List[PbrMaterial],
):
self.vertices = vertices.float()
self.faces = faces.int()
self.material_ids = material_ids # [M]
self.uv_coords = uv_coords # [M, 3, 2]
self.materials = materials
self.layout = {
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
def to(self, device, non_blocking=False):
return MeshWithPbrMaterial(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.material_ids.to(device, non_blocking=non_blocking),
self.uv_coords.to(device, non_blocking=non_blocking),
[material.to(device, non_blocking=non_blocking) for material in self.materials],
)
class MeshWithVoxel(Mesh, Voxel):
def __init__(self,
vertices: torch.Tensor,
faces: torch.Tensor,
origin: list,
voxel_size: float,
coords: torch.Tensor,
attrs: torch.Tensor,
voxel_shape: torch.Size,
layout: Dict = {},
):
self.vertices = vertices.float()
self.faces = faces.int()
self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
self.voxel_size = voxel_size
self.coords = coords
self.attrs = attrs
self.voxel_shape = voxel_shape
self.layout = layout
def to(self, device, non_blocking=False):
return MeshWithVoxel(
self.vertices.to(device, non_blocking=non_blocking),
self.faces.to(device, non_blocking=non_blocking),
self.origin.tolist(),
self.voxel_size,
self.coords.to(device, non_blocking=non_blocking),
self.attrs.to(device, non_blocking=non_blocking),
self.voxel_shape,
self.layout,
)
def query_attrs(self, xyz):
grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3)
vertex_attrs = grid_sample_3d(
self.attrs,
torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1),
self.voxel_shape,
grid,
mode='trilinear'
)[0]
return vertex_attrs
def query_vertex_attrs(self):
return self.query_attrs(self.vertices)
from .voxel_model import Voxel
\ No newline at end of file
from typing import Dict
import torch
class Voxel:
def __init__(
self,
origin: list,
voxel_size: float,
coords: torch.Tensor = None,
attrs: torch.Tensor = None,
layout: Dict = {},
device: torch.device = 'cuda'
):
self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
self.voxel_size = voxel_size
self.coords = coords
self.attrs = attrs
self.layout = layout
self.device = device
@property
def position(self):
return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]
def split_attrs(self):
return {
k: self.attrs[:, self.layout[k]]
for k in self.layout
}
def save(self, path):
# lazy import
if 'o_voxel' not in globals():
import o_voxel
o_voxel.io.write(
path,
self.coords,
self.split_attrs(),
)
def load(self, path):
# lazy import
if 'o_voxel' not in globals():
import o_voxel
coord, attrs = o_voxel.io.read(path)
self.coords = coord.int().to(self.device)
self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device)
# build layout
start = 0
self.layout = {}
for k in attrs:
self.layout[k] = slice(start, start + attrs[k].shape[1])
start += attrs[k].shape[1]
import importlib
__attributes = {
'BasicTrainer': 'basic',
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
'ShapeVaeTrainer': 'vae.shape_vae',
'PbrVaeTrainer': 'vae.pbr_vae',
'FlowMatchingTrainer': 'flow_matching.flow_matching',
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned',
'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import BasicTrainer
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
from .vae.shape_vae import ShapeVaeTrainer
from .vae.pbr_vae import PbrVaeTrainer
from .flow_matching.flow_matching import (
FlowMatchingTrainer,
FlowMatchingCFGTrainer,
TextConditionedFlowMatchingCFGTrainer,
ImageConditionedFlowMatchingCFGTrainer,
)
from .flow_matching.sparse_flow_matching import (
SparseFlowMatchingTrainer,
SparseFlowMatchingCFGTrainer,
TextConditionedSparseFlowMatchingCFGTrainer,
ImageConditionedSparseFlowMatchingCFGTrainer,
)
from .flow_matching.mixins.image_conditioned import (
DinoV2FeatureExtractor,
DinoV3FeatureExtractor,
)
from abc import abstractmethod
import os
import time
import json
import copy
import threading
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from torchvision import utils
from torch.utils.tensorboard import SummaryWriter
from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
from ..utils.dist_utils import *
from ..utils import grad_clip_utils, elastic_utils
class BasicTrainer:
"""
Trainer for basic training loop.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
mix_precision_mode (str):
- None: No mixed precision.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
mix_precision_dtype (str): Mixed precision dtype.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
parallel_mode (str): Parallel mode. Options are 'ddp'.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
"""
def __init__(self,
models,
dataset,
*,
output_dir,
load_dir,
step,
max_steps,
batch_size=None,
batch_size_per_gpu=None,
batch_split=None,
optimizer={},
lr_scheduler=None,
elastic=None,
grad_clip=None,
ema_rate=0.9999,
fp16_mode=None,
mix_precision_mode='inflat_all',
mix_precision_dtype='float16',
fp16_scale_growth=1e-3,
parallel_mode='ddp',
finetune_ckpt=None,
log_param_stats=False,
prefetch_data=True,
snapshot_batch_size=4,
i_print=1000,
i_log=500,
i_sample=10000,
i_save=10000,
i_ddpcheck=10000,
**kwargs
):
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
self.models = models
self.dataset = dataset
self.batch_split = batch_split if batch_split is not None else 1
self.max_steps = max_steps
self.optimizer_config = optimizer
self.lr_scheduler_config = lr_scheduler
self.elastic_controller_config = elastic
self.grad_clip = grad_clip
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
if fp16_mode is not None:
mix_precision_dtype = 'float16'
mix_precision_mode = fp16_mode
self.mix_precision_mode = mix_precision_mode
self.mix_precision_dtype = str_to_dtype(mix_precision_dtype)
self.fp16_scale_growth = fp16_scale_growth
self.parallel_mode = parallel_mode
self.log_param_stats = log_param_stats
self.prefetch_data = prefetch_data
self.snapshot_batch_size = snapshot_batch_size
self.log = []
if self.prefetch_data:
self._data_prefetched = None
self.output_dir = output_dir
self.i_print = i_print
self.i_log = i_log
self.i_sample = i_sample
self.i_save = i_save
self.i_ddpcheck = i_ddpcheck
if dist.is_initialized():
# Multi-GPU params
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
self.local_rank = dist.get_rank() % torch.cuda.device_count()
self.is_master = self.rank == 0
else:
# Single-GPU params
self.world_size = 1
self.rank = 0
self.local_rank = 0
self.is_master = True
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
self.init_models_and_more(**kwargs)
self.prepare_dataloader(**kwargs)
# Load checkpoint
self.step = 0
if load_dir is not None and step is not None:
self.load(load_dir, step)
elif finetune_ckpt is not None:
self.finetune_from(finetune_ckpt)
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
if self.parallel_mode == 'ddp' and self.world_size > 1:
self.check_ddp()
if self.is_master:
print('\n\nTrainer initialized.')
print(self)
def __str__(self):
lines = []
lines.append(self.__class__.__name__)
lines.append(f' - Models:')
for name, model in self.models.items():
lines.append(f' - {name}: {model.__class__.__name__}')
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
lines.append(f' - Dataloader:')
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
lines.append(f' - Num workers: {self.dataloader.num_workers}')
lines.append(f' - Number of steps: {self.max_steps}')
lines.append(f' - Number of GPUs: {self.world_size}')
lines.append(f' - Batch size: {self.batch_size}')
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
lines.append(f' - Batch split: {self.batch_split}')
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
if self.lr_scheduler_config is not None:
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
if self.elastic_controller_config is not None:
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
if self.grad_clip is not None:
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
lines.append(f' - EMA rate: {self.ema_rate}')
lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}')
lines.append(f' - Mixed precision mode: {self.mix_precision_mode}')
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}')
lines.append(f' - Parallel mode: {self.parallel_mode}')
return '\n'.join(lines)
@property
def device(self):
for _, model in self.models.items():
if hasattr(model, 'device'):
return model.device
return next(list(self.models.values())[0].parameters()).device
def init_models_and_more(self, **kwargs):
"""
Initialize models and more.
"""
if self.world_size > 1:
# Prepare distributed data parallel
self.training_models = {
name: DDP(
model,
device_ids=[self.local_rank],
output_device=self.local_rank,
bucket_cap_mb=128,
find_unused_parameters=False
)
for name, model in self.models.items()
}
else:
self.training_models = self.models
# Build master params
self.model_params = sum(
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
, [])
if self.mix_precision_mode == 'amp':
self.master_params = self.model_params
if self.mix_precision_dtype == torch.float16:
self.scaler = torch.GradScaler()
elif self.mix_precision_mode == 'inflat_all':
self.master_params = make_master_params(self.model_params)
if self.mix_precision_dtype == torch.float16:
self.log_scale = 20.0
elif self.mix_precision_mode is None:
self.master_params = self.model_params
else:
raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.')
# Build EMA params
if self.is_master:
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
# Initialize optimizer
if hasattr(torch.optim, self.optimizer_config['name']):
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
else:
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
# Initalize learning rate scheduler
if self.lr_scheduler_config is not None:
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
else:
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
# Initialize elastic memory controller
if self.elastic_controller_config is not None:
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
for model in self.models.values():
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
model.register_memory_controller(self.elastic_controller)
# Initialize gradient clipper
if self.grad_clip is not None:
if isinstance(self.grad_clip, (float, int)):
self.grad_clip = float(self.grad_clip)
else:
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = ResumableSampler(
self.dataset,
shuffle=True,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
def _master_params_to_state_dicts(self, master_params):
"""
Convert master params to dict of state_dicts.
"""
if self.mix_precision_mode == 'inflat_all':
master_params = unflatten_master_params(self.model_params, master_params)
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
for i, (model_name, param_name) in enumerate(master_params_names):
state_dicts[model_name][param_name] = master_params[i]
return state_dicts
def _state_dicts_to_master_params(self, master_params, state_dicts):
"""
Convert a state_dict to master params.
"""
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
if self.mix_precision_mode == 'inflat_all':
model_params_to_master_params(params, master_params)
else:
for i, param in enumerate(params):
master_params[i].data.copy_(param.data)
def load(self, load_dir, step=0):
"""
Load a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print(f'\nLoading checkpoint from step {step}...', end='')
model_ckpts = {}
for name, model in self.models.items():
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
self._state_dicts_to_master_params(self.master_params, model_ckpts)
del model_ckpts
if self.is_master:
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = {}
for name, model in self.models.items():
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
ema_ckpts[name] = ema_ckpt
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
del ema_ckpts
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
self.step = misc_ckpt['step']
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
self.scaler.load_state_dict(misc_ckpt['scaler'])
elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
self.log_scale = misc_ckpt['log_scale']
if self.lr_scheduler_config is not None:
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
if self.elastic_controller_config is not None:
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
del misc_ckpt
if self.world_size > 1:
dist.barrier()
if self.is_master:
print(' Done.')
if self.world_size > 1:
self.check_ddp()
def save(self, non_blocking=True):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
assert self.is_master, 'save() should be called only by the rank 0 process.'
print(f'\nSaving checkpoint at step {self.step}...', end='')
model_ckpts = self._master_params_to_state_dicts(self.master_params)
for name, model_ckpt in model_ckpts.items():
model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving
if non_blocking:
threading.Thread(
target=torch.save,
args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')),
).start()
else:
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
for name, ema_ckpt in ema_ckpts.items():
ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving
if non_blocking:
threading.Thread(
target=torch.save,
args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')),
).start()
else:
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
misc_ckpt = {
'optimizer': self.optimizer.state_dict(),
'step': self.step,
'data_sampler': self.data_sampler.state_dict(),
}
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
misc_ckpt['scaler'] = self.scaler.state_dict()
elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
misc_ckpt['log_scale'] = self.log_scale
if self.lr_scheduler_config is not None:
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
if self.elastic_controller_config is not None:
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
if non_blocking:
threading.Thread(
target=torch.save,
args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')),
).start()
else:
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
print(' Done.')
def finetune_from(self, finetune_ckpt):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print('\nFinetuning from:')
for name, path in finetune_ckpt.items():
print(f' - {name}: {path}')
model_ckpts = {}
for name, model in self.models.items():
model_state_dict = model.state_dict()
if name in finetune_ckpt:
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
for k, v in model_ckpt.items():
if k not in model_state_dict:
if self.is_master:
print(f'Warning: {k} not found in model_state_dict, skipped.')
model_ckpt[k] = None
elif model_ckpt[k].shape != model_state_dict[k].shape:
if self.is_master:
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
model_ckpt[k] = model_state_dict[k]
model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
else:
if self.is_master:
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
model_ckpts[name] = model_state_dict
self._state_dicts_to_master_params(self.master_params, model_ckpts)
if self.is_master:
for i, ema_rate in enumerate(self.ema_rate):
self._state_dicts_to_master_params(self.ema_params[i], model_ckpts)
del model_ckpts
if self.world_size > 1:
dist.barrier()
if self.is_master:
print('Done.')
if self.world_size > 1:
self.check_ddp()
@abstractmethod
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
"""
Run a snapshot of the model.
"""
pass
@torch.no_grad()
def visualize_sample(self, sample):
"""
Convert a sample to an image.
"""
if hasattr(self.dataset, 'visualize_sample'):
return self.dataset.visualize_sample(sample)
else:
return sample
@torch.no_grad()
def snapshot_dataset(self, num_samples=100, batch_size=4):
"""
Sample images from the dataset.
"""
dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=batch_size,
num_workers=1,
shuffle=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
save_cfg = {}
for i in range(0, num_samples, batch_size):
data = next(iter(dataloader))
data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()}
data = recursive_to_device(data, self.device)
vis = self.visualize_sample(data)
if isinstance(vis, dict):
for k, v in vis.items():
if f'dataset_{k}' not in save_cfg:
save_cfg[f'dataset_{k}'] = []
save_cfg[f'dataset_{k}'].append(v)
else:
if 'dataset' not in save_cfg:
save_cfg['dataset'] = []
save_cfg['dataset'].append(vis)
for name, image in save_cfg.items():
utils.save_image(
torch.cat(image, dim=0),
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
@torch.no_grad()
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
"""
Sample images from the model.
NOTE: This function should be called by all processes.
"""
if self.is_master:
print(f'\nSampling {num_samples} images...', end='')
if suffix is None:
suffix = f'step{self.step:07d}'
# Assign tasks
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
with amp_context():
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
# Preprocess images
for key in list(samples.keys()):
if samples[key]['type'] == 'sample':
vis = self.visualize_sample(samples[key]['value'])
if isinstance(vis, dict):
for k, v in vis.items():
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
del samples[key]
else:
samples[key] = {'value': vis, 'type': 'image'}
# Gather results
if self.world_size > 1:
for key in samples.keys():
samples[key]['value'] = samples[key]['value'].contiguous()
if self.is_master:
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
else:
all_images = []
dist.gather(samples[key]['value'], all_images, dst=0)
if self.is_master:
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
# Save images
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
for key in samples.keys():
if samples[key]['type'] == 'image':
utils.save_image(
samples[key]['value'],
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
elif samples[key]['type'] == 'number':
min = samples[key]['value'].min()
max = samples[key]['value'].max()
images = (samples[key]['value'] - min) / (max - min)
images = utils.make_grid(
images,
nrow=int(np.sqrt(num_samples)),
normalize=False,
)
save_image_with_notes(
images,
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
notes=f'{key} min: {min}, max: {max}',
)
if self.is_master:
print(' Done.')
def update_ema(self):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
for i, ema_rate in enumerate(self.ema_rate):
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
def check_ddp(self):
"""
Check if DDP is working properly.
Should be called by all process.
"""
if self.is_master:
print('\nPerforming DDP check...')
if self.is_master:
print('Checking if parameters are consistent across processes...')
dist.barrier()
try:
for p in self.master_params:
# split to avoid OOM
for i in range(0, p.numel(), 10000000):
sub_size = min(10000000, p.numel() - i)
sub_p = p.detach().view(-1)[i:i+sub_size]
# gather from all processes
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
dist.all_gather(sub_p_gather, sub_p)
# check if equal
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
except AssertionError as e:
if self.is_master:
print(f'\n\033[91mError: {e}\033[0m')
print('DDP check failed.')
raise e
dist.barrier()
if self.is_master:
print('Done.')
@abstractmethod
def training_losses(**mb_data):
"""
Compute training losses.
"""
pass
def load_data(self):
"""
Load data.
"""
if self.prefetch_data:
if self._data_prefetched is None:
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
data = self._data_prefetched
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
else:
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
if isinstance(data, dict):
if self.batch_split == 1:
data_list = [data]
else:
batch_size = list(data.values())[0].shape[0]
data_list = [
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
for i in range(self.batch_split)
]
elif isinstance(data, list):
data_list = data
else:
raise ValueError('Data must be a dict or a list of dicts.')
return data_list
def run_step(self, data_list):
"""
Run a training step.
"""
step_log = {'loss': {}, 'status': {}}
amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
# Train
losses = []
statuses = []
elastic_controller_logs = []
zero_grad(self.model_params)
for i, mb_data in enumerate(data_list):
## sync at the end of each batch split
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
with nested_contexts(*sync_contexts), elastic_controller_context():
with amp_context():
loss, status = self.training_losses(**mb_data)
l = loss['loss'] / len(data_list)
## backward
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
self.scaler.scale(l).backward()
elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
scaled_l = l * (2 ** self.log_scale)
scaled_l.backward()
else:
l.backward()
## log
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
if self.elastic_controller_config is not None:
elastic_controller_logs.append(self.elastic_controller.log())
## gradient clip
if self.grad_clip is not None:
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
self.scaler.unscale_(self.optimizer)
elif self.mix_precision_mode == 'inflat_all':
model_grads_to_master_grads(self.model_params, self.master_params)
if self.mix_precision_dtype == torch.float16:
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
if isinstance(self.grad_clip, float):
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
else:
grad_norm = self.grad_clip(self.master_params)
if torch.isfinite(grad_norm):
statuses[-1]['grad_norm'] = grad_norm.item()
## step
if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
prev_scale = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
elif self.mix_precision_mode == 'inflat_all':
if self.mix_precision_dtype == torch.float16:
prev_scale = 2 ** self.log_scale
if not any(not p.grad.isfinite().all() for p in self.model_params):
if self.grad_clip is None:
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
self.optimizer.step()
master_params_to_model_params(self.model_params, self.master_params)
self.log_scale += self.fp16_scale_growth
else:
self.log_scale -= 1
else:
prev_scale = 1.0
if self.grad_clip is None:
model_grads_to_master_grads(self.model_params, self.master_params)
if not any(not p.grad.isfinite().all() for p in self.master_params):
self.optimizer.step()
master_params_to_model_params(self.model_params, self.master_params)
else:
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
else:
prev_scale = 1.0
if not any(not p.grad.isfinite().all() for p in self.model_params):
self.optimizer.step()
else:
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
## adjust learning rate
if self.lr_scheduler_config is not None:
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
self.lr_scheduler.step()
# Logs
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
if self.elastic_controller_config is not None:
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
if self.grad_clip is not None:
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
# Check grad and norm of each param
if self.log_param_stats:
param_norms = {}
param_grads = {}
for model_name, model in self.models.items():
for name, param in model.named_parameters():
if param.requires_grad:
param_norms[f'{model_name}.{name}'] = param.norm().item()
if param.grad is not None and torch.isfinite(param.grad).all():
param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale
step_log['param_norms'] = param_norms
step_log['param_grads'] = param_grads
# Update exponential moving average
if self.is_master:
self.update_ema()
return step_log
def save_logs(self):
log_str = '\n'.join([
f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log
])
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
log_file.write(log_str + '\n')
# show with mlflow
log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))]
log_show = dict_reduce(log_show, lambda x: np.mean(x))
log_show = dict_flatten(log_show, sep='/')
for key, value in log_show.items():
self.writer.add_scalar(key, value, self.step)
self.log = []
def check_abort(self):
"""
Check if training should be aborted due to certain conditions.
"""
# 1. If log_scale in inflat_all mode is less than 0
if self.mix_precision_dtype == torch.float16 and \
self.mix_precision_mode == 'inflat_all' and \
self.log_scale < 0:
if self.is_master:
print ('\n\n\033[91m')
print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.')
print ('This indicates that the model is diverging. You should look into the model and the data.')
print ('\033[0m')
self.save(non_blocking=False)
self.save_logs()
if self.world_size > 1:
dist.barrier()
raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.')
def run(self):
"""
Run training.
"""
if self.is_master:
print('\nStarting training...')
self.snapshot_dataset(batch_size=self.snapshot_batch_size)
if self.step == 0:
self.snapshot(suffix='init', batch_size=self.snapshot_batch_size)
else: # resume
self.snapshot(suffix=f'resume_step{self.step:07d}', batch_size=self.snapshot_batch_size)
time_last_print = 0.0
time_elapsed = 0.0
while self.step < self.max_steps:
time_start = time.time()
data_list = self.load_data()
step_log = self.run_step(data_list)
time_end = time.time()
time_elapsed += time_end - time_start
self.step += 1
# Print progress
if self.is_master and self.step % self.i_print == 0:
speed = self.i_print / (time_elapsed - time_last_print) * 3600
columns = [
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
f'Elapsed: {time_elapsed / 3600:.2f} h',
f'Speed: {speed:.2f} steps/h',
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
]
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
time_last_print = time_elapsed
# Check ddp
if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
self.check_ddp()
# Sample images
if self.step % self.i_sample == 0:
self.snapshot()
if self.is_master:
self.log.append((self.step, {}))
# Log time
self.log[-1][1]['time'] = {
'step': time_end - time_start,
'elapsed': time_elapsed,
}
# Log losses
if step_log is not None:
self.log[-1][1].update(step_log)
# Log scale
if self.mix_precision_dtype == torch.float16:
if self.mix_precision_mode == 'amp':
self.log[-1][1]['scale'] = self.scaler.get_scale()
elif self.mix_precision_mode == 'inflat_all':
self.log[-1][1]['log_scale'] = self.log_scale
# Save log
if self.step % self.i_log == 0:
self.save_logs()
# Save checkpoint
if self.step % self.i_save == 0:
self.save()
# Check abort
self.check_abort()
self.snapshot(suffix='final', batch_size=self.snapshot_batch_size)
if self.world_size > 1:
dist.barrier()
if self.is_master:
self.writer.close()
print('Training finished.')
def profile(self, wait=2, warmup=3, active=5):
"""
Profile the training loop.
"""
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
profile_memory=True,
with_stack=True,
) as prof:
for _ in range(wait + warmup + active):
self.run_step()
prof.step()
from typing import *
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ..basic import BasicTrainer
from ...pipelines import samplers
from ...utils.general_utils import dict_reduce
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin
class FlowMatchingTrainer(BasicTrainer):
"""
Trainer for diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def __init__(
self,
*args,
t_schedule: dict = {
'name': 'logitNormal',
'args': {
'mean': 0.0,
'std': 1.0,
}
},
sigma_min: float = 1e-5,
**kwargs
):
super().__init__(*args, **kwargs)
self.t_schedule = t_schedule
self.sigma_min = sigma_min
def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
t: The [N] tensor of diffusion steps [0-1].
noise: If specified, use this noise instead of generating new noise.
Returns:
x_t, the noisy version of x_0 under timestep t.
"""
if noise is None:
noise = torch.randn_like(x_0)
assert noise.shape == x_0.shape, "noise must have same shape as x_0"
t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
return x_t
def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""
Get original image from noisy version under timestep t.
"""
assert noise.shape == x_t.shape, "noise must have same shape as x_t"
t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
return x_0
def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Compute the velocity of the diffusion process at time t.
"""
return (1 - self.sigma_min) * noise - x_0
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
return {'cond': cond, **kwargs}
def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
"""
Get the sampler for the diffusion process.
"""
return samplers.FlowEulerSampler(self.sigma_min)
def vis_cond(self, **kwargs):
"""
Visualize the conditioning data.
"""
return {}
def sample_t(self, batch_size: int) -> torch.Tensor:
"""
Sample timesteps.
"""
if self.t_schedule['name'] == 'uniform':
t = torch.rand(batch_size)
elif self.t_schedule['name'] == 'logitNormal':
mean = self.t_schedule['args']['mean']
std = self.t_schedule['args']['std']
t = torch.sigmoid(torch.randn(batch_size) * std + mean)
else:
raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
return t
def training_losses(
self,
x_0: torch.Tensor,
cond=None,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
may also contain other keys for different terms.
"""
noise = torch.randn_like(x_0)
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
x_t = self.diffuse(x_0, t, noise=noise)
cond = self.get_cond(cond, **kwargs)
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
assert pred.shape == noise.shape == x_0.shape
target = self.get_v(x_0, noise, t)
terms = edict()
terms["mse"] = F.mse_loss(pred, target)
terms["loss"] = terms["mse"]
# log loss with time bins
mse_per_instance = np.array([
F.mse_loss(pred[i], target[i]).item()
for i in range(x_0.shape[0])
])
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
for i in range(10):
if (time_bin == i).sum() != 0:
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
sampler = self.get_sampler()
sample_gt = []
sample = []
cond_vis = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
noise = torch.randn_like(data['x_0'])
sample_gt.append(data['x_0'])
cond_vis.append(self.vis_cond(**data))
del data['x_0']
args = self.get_inference_cond(**data)
res = sampler.sample(
self.models['denoiser'],
noise=noise,
**args,
steps=50, guidance_strength=3.0, verbose=verbose,
)
sample.append(res.samples)
sample_gt = torch.cat(sample_gt, dim=0)
sample = torch.cat(sample, dim=0)
sample_dict = {
'sample_gt': {'value': sample_gt, 'type': 'sample'},
'sample': {'value': sample, 'type': 'sample'},
}
sample_dict.update(dict_reduce(cond_vis, None, {
'value': lambda x: torch.cat(x, dim=0),
'type': lambda x: x[0],
}))
return sample_dict
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
"""
Trainer for diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
"""
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
"""
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
import torch
import numpy as np
from ....utils.general_utils import dict_foreach
from ....pipelines import samplers
class ClassifierFreeGuidanceMixin:
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
super().__init__(*args, **kwargs)
self.p_uncond = p_uncond
def get_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
if self.p_uncond > 0:
# randomly drop the class label
def get_batch_size(cond):
if isinstance(cond, torch.Tensor):
return cond.shape[0]
elif isinstance(cond, list):
return len(cond)
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
B = get_batch_size(ref_cond)
def select(cond, neg_cond, mask):
if isinstance(cond, torch.Tensor):
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
return torch.where(mask, neg_cond, cond)
elif isinstance(cond, list):
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
mask = list(np.random.rand(B) < self.p_uncond)
if not isinstance(cond, dict):
cond = select(cond, neg_cond, mask)
else:
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
return cond
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data for inference.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
"""
Get the sampler for the diffusion process.
"""
return samplers.FlowEulerCfgSampler(self.sigma_min)
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import DINOv3ViTModel
import numpy as np
from PIL import Image
from ....utils import dist_utils
class DinoV2FeatureExtractor:
"""
Feature extractor for DINOv2 models.
"""
def __init__(self, model_name: str):
self.model_name = model_name
self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True)
self.model.eval()
self.transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def to(self, device):
self.model.to(device)
def cuda(self):
self.model.cuda()
def cpu(self):
self.model.cpu()
@torch.no_grad()
def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
image = self.transform(image).cuda()
features = self.model(image, is_training=True)['x_prenorm']
patchtokens = F.layer_norm(features, features.shape[-1:])
return patchtokens
class DinoV3FeatureExtractor:
"""
Feature extractor for DINOv3 models.
"""
def __init__(self, model_name: str, image_size=512):
self.model_name = model_name
self.model = DINOv3ViTModel.from_pretrained(model_name)
self.model.eval()
self.image_size = image_size
self.transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def to(self, device):
self.model.to(device)
def cuda(self):
self.model.cuda()
def cpu(self):
self.model.cpu()
def extract_features(self, image: torch.Tensor) -> torch.Tensor:
image = image.to(self.model.embeddings.patch_embeddings.weight.dtype)
hidden_states = self.model.embeddings(image, bool_masked_pos=None)
position_embeddings = self.model.rope_embeddings(image)
for i, layer_module in enumerate(self.model.layer):
hidden_states = layer_module(
hidden_states,
position_embeddings=position_embeddings,
)
return F.layer_norm(hidden_states, hidden_states.shape[-1:])
@torch.no_grad()
def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Extract features from the image.
Args:
image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images.
Returns:
A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
image = self.transform(image).cuda()
features = self.extract_features(image)
return features
class ImageConditionedMixin:
"""
Mixin for image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def __init__(self, *args, image_cond_model: dict, **kwargs):
super().__init__(*args, **kwargs)
self.image_cond_model_config = image_cond_model
self.image_cond_model = None # the model is init lazily
def _init_image_cond_model(self):
"""
Initialize the image conditioning model.
"""
with dist_utils.local_master_first():
self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
self.image_cond_model.cuda()
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Encode the image.
"""
if self.image_cond_model is None:
self._init_image_cond_model()
features = self.image_cond_model(image)
return features
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_inference_cond(cond, **kwargs)
return cond
def vis_cond(self, cond, **kwargs):
"""
Visualize the conditioning data.
"""
return {'image': {'value': cond, 'type': 'image'}}
class MultiImageConditionedMixin:
"""
Mixin for multiple-image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def __init__(self, *args, image_cond_model: dict, **kwargs):
super().__init__(*args, **kwargs)
self.image_cond_model_config = image_cond_model
self.image_cond_model = None # the model is init lazily
def _init_image_cond_model(self):
"""
Initialize the image conditioning model.
"""
with dist_utils.local_master_first():
self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {}))
@torch.no_grad()
def encode_images(self, images: Union[List[torch.Tensor], List[List[Image.Image]]]) -> List[torch.Tensor]:
"""
Encode the image.
"""
if self.image_cond_model is None:
self._init_image_cond_model()
seqlen = [len(i) for i in images]
images = torch.cat(images, dim=0) if isinstance(images[0], torch.Tensor) else sum(images, [])
features = self.image_cond_model(images)
features = torch.split(features, seqlen)
features = [feature.reshape(-1, feature.shape[-1]) for feature in features]
return features
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_images(cond)
kwargs['neg_cond'] = [
torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond))
]
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_images(cond)
kwargs['neg_cond'] = [
torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond))
]
cond = super().get_inference_cond(cond, **kwargs)
return cond
def vis_cond(self, cond, **kwargs):
"""
Visualize the conditioning data.
"""
H, W = cond[0].shape[-2:]
vis = []
for images in cond:
canvas = torch.zeros(3, H * 2, W * 2, device=images.device, dtype=images.dtype)
for i, image in enumerate(images):
if i == 4:
break
kh = i // 2
kw = i % 2
canvas[:, kh*H:(kh+1)*H, kw*W:(kw+1)*W] = image
vis.append(canvas)
vis = torch.stack(vis)
return {'image': {'value': vis, 'type': 'image'}}
from typing import *
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import torch
from transformers import AutoTokenizer, CLIPTextModel
from ....utils import dist_utils
class TextConditionedMixin:
"""
Mixin for text-conditioned models.
Args:
text_cond_model: The text conditioning model.
"""
def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
super().__init__(*args, **kwargs)
self.text_cond_model_name = text_cond_model
self.text_cond_model = None # the model is init lazily
def _init_text_cond_model(self):
"""
Initialize the text conditioning model.
"""
# load model
with dist_utils.local_master_first():
model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
model.eval()
model = model.cuda()
self.text_cond_model = {
'model': model,
'tokenizer': tokenizer,
}
self.text_cond_model['null_cond'] = self.encode_text([''])
@torch.no_grad()
def encode_text(self, text: List[str]) -> torch.Tensor:
"""
Encode the text.
"""
assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
if self.text_cond_model is None:
self._init_text_cond_model()
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
tokens = encoding['input_ids'].cuda()
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
return embeddings
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_text(cond)
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_text(cond)
kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
cond = super().get_inference_cond(cond, **kwargs)
return cond
from typing import *
import os
import copy
import functools
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ...modules import sparse as sp
from ...utils.general_utils import dict_reduce
from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
from .flow_matching import FlowMatchingTrainer
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin, MultiImageConditionedMixin
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
"""
Trainer for sparse diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = BalancedResumableSampler(
self.dataset,
shuffle=True,
batch_size=self.batch_size_per_gpu,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
def training_losses(
self,
x_0: sp.SparseTensor,
cond=None,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x ... x C] sparse tensor of the inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
may also contain other keys for different terms.
"""
noise = x_0.replace(torch.randn_like(x_0.feats))
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
x_t = self.diffuse(x_0, t, noise=noise)
cond = self.get_cond(cond, **kwargs)
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
assert pred.shape == noise.shape == x_0.shape
target = self.get_v(x_0, noise, t)
terms = edict()
terms["mse"] = F.mse_loss(pred.feats, target.feats)
terms["loss"] = terms["mse"]
# log loss with time bins
mse_per_instance = np.array([
F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
for i in range(x_0.shape[0])
])
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
for i in range(10):
if (time_bin == i).sum() != 0:
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=num_samples,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
data = next(iter(dataloader))
# inference
sampler = self.get_sampler()
sample = []
cond_vis = []
for i in range(0, num_samples, batch_size):
batch_data = {k: v[i:i+batch_size] for k, v in data.items()}
batch_data = recursive_to_device(batch_data, 'cuda')
noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats))
cond_vis.append(self.vis_cond(**batch_data))
del batch_data['x_0']
args = self.get_inference_cond(**batch_data)
res = sampler.sample(
self.models['denoiser'],
noise=noise,
**args,
steps=12, guidance_strength=3.0, verbose=verbose,
)
sample.append(res.samples)
sample = sp.sparse_cat(sample)
sample_gt = {k: v for k, v in data.items()}
sample = {k: v if k != 'x_0' else sample for k, v in data.items()}
sample_dict = {
'sample_gt': {'value': sample_gt, 'type': 'sample'},
'sample': {'value': sample, 'type': 'sample'},
}
sample_dict.update(dict_reduce(cond_vis, None, {
'value': lambda x: torch.cat(x, dim=0),
'type': lambda x: x[0],
}))
return sample_dict
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
"""
Trainer for sparse diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
"""
Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
"""
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
class MultiImageConditionedSparseFlowMatchingCFGTrainer(MultiImageConditionedMixin, SparseFlowMatchingCFGTrainer):
"""
Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
import torch
import torch.nn as nn
# FP16 utils
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def str_to_dtype(dtype_str: str):
return {
'f16': torch.float16,
'fp16': torch.float16,
'float16': torch.float16,
'bf16': torch.bfloat16,
'bfloat16': torch.bfloat16,
'f32': torch.float32,
'fp32': torch.float32,
'float32': torch.float32,
}[dtype_str]
def make_master_params(model_params):
"""
Copy model parameters into a inflated tensor of full-precision parameters.
"""
master_params = _flatten_dense_tensors(
[param.detach().float() for param in model_params]
)
master_params = nn.Parameter(master_params)
master_params.requires_grad = True
return [master_params]
def unflatten_master_params(model_params, master_params):
"""
Unflatten the master parameters to look like model_params.
"""
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
def model_params_to_master_params(model_params, master_params):
"""
Copy the model parameter data into the master parameters.
"""
master_params[0].detach().copy_(
_flatten_dense_tensors([param.detach().float() for param in model_params])
)
def master_params_to_model_params(model_params, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
for param, master_param in zip(
model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
):
param.detach().copy_(master_param)
def model_grads_to_master_grads(model_params, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params[0].grad = _flatten_dense_tensors(
[param.grad.data.detach().float() for param in model_params]
)
def zero_grad(model_params):
for param in model_params:
if param.grad is not None:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR
class LinearWarmupLRScheduler(LambdaLR):
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
self.warmup_steps = warmup_steps
super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
def lr_lambda(self, current_step):
if current_step < self.warmup_steps:
return float(current_step + 1) / self.warmup_steps
return 1.0
\ No newline at end of file
from typing import *
import os
import copy
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import utils3d
from easydict import EasyDict as edict
from ..basic import BasicTrainer
from ...modules import sparse as sp
from ...renderers import MeshRenderer
from ...representations import Mesh, MeshWithPbrMaterial, MeshWithVoxel
from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
class PbrVaeTrainer(BasicTrainer):
"""
Trainer for PBR attributes VAE
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
loss_type (str): Loss type.
lambda_kl (float): KL loss weight.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
"""
def __init__(
self,
*args,
loss_type: str = 'l1',
lambda_kl: float = 1e-6,
lambda_ssim: float = 0.2,
lambda_lpips: float = 0.2,
lambda_render: float = 1.0,
render_resolution: float = 1024,
camera_randomization_config: dict = {
'radius_range': [2, 100],
},
**kwargs
):
super().__init__(*args, **kwargs)
self.loss_type = loss_type
self.lambda_kl = lambda_kl
self.lambda_ssim = lambda_ssim
self.lambda_lpips = lambda_lpips
self.lambda_render = lambda_render
self.camera_randomization_config = camera_randomization_config
self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device)
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = BalancedResumableSampler(
self.dataset,
shuffle=True,
batch_size=self.batch_size_per_gpu,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
def _randomize_camera(self, num_samples: int):
# sample radius and fov
r_min, r_max = self.camera_randomization_config['radius_range']
k_min = 1 / r_max**2
k_max = 1 / r_min**2
ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min
radius = 1 / torch.sqrt(ks)
fov = 2 * torch.arcsin(0.5 / radius)
origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1)
# build camera
extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device))
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
near = [np.random.uniform(r - 1, r) for r in radius.tolist()]
return {
'extrinsics': extrinsics,
'intrinsics': intrinsics,
'near': near,
}
def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
) -> Dict[str, torch.Tensor]:
"""
Render a batch of representations.
Args:
reps: The dictionary of lists of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
Returns:
a dict with
base_color : [N x 3 x H x W] tensor of base color.
metallic : [N x 1 x H x W] tensor of metallic.
roughness : [N x 1 x H x W] tensor of roughness.
alpha : [N x 1 x H x W] tensor of alpha.
"""
ret = {k : [] for k in ['base_color', 'metallic', 'roughness', 'alpha']}
for i, rep in enumerate(reps):
self.renderer.rendering_options['near'] = near[i]
self.renderer.rendering_options['far'] = near[i] + 2
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=['attr'])
for k in out_dict:
ret[k].append(out_dict[k])
for k in ret:
ret[k] = torch.stack(ret[k])
return ret
def training_losses(
self,
x: sp.SparseTensor,
mesh: List[MeshWithPbrMaterial] = None,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
x (SparseTensor): Input sparse tensor for pbr materials.
mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z, mean, logvar = self.training_models['encoder'](x, sample_posterior=True, return_raw=True)
y = self.training_models['decoder'](z)
terms = edict(loss = 0.0)
# direct regression
if self.loss_type == 'l1':
terms["l1"] = l1_loss(x.feats, y.feats)
terms["loss"] = terms["loss"] + terms["l1"]
elif self.loss_type == 'l2':
terms["l2"] = l2_loss(x.feats, y.feats)
terms["loss"] = terms["loss"] + terms["l2"]
else:
raise ValueError(f'Invalid loss type {self.loss_type}')
# rendering loss
if self.lambda_render != 0.0:
recon = [MeshWithVoxel(
m.vertices,
m.faces,
[-0.5, -0.5, -0.5],
1 / self.dataset.resolution,
v.coords[:, 1:],
v.feats * 0.5 + 0.5,
torch.Size([*v.shape, *v.spatial_shape]),
layout={
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
) for m, v in zip(mesh, y)]
cameras = self._randomize_camera(len(mesh))
gt_renders = self._render_batch(mesh, **cameras)
pred_renders = self._render_batch(recon, **cameras)
gt_base_color = gt_renders['base_color']
pred_base_color = pred_renders['base_color']
gt_mra = torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1)
pred_mra = torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1)
terms['render/base_color/ssim'] = 1 - ssim(pred_base_color, gt_base_color)
terms['render/base_color/lpips'] = lpips(pred_base_color, gt_base_color)
terms['render/mra/ssim'] = 1 - ssim(pred_mra, gt_mra)
terms['render/mra/lpips'] = lpips(pred_mra, gt_mra)
terms['loss'] = terms['loss'] + \
self.lambda_render * (self.lambda_ssim * terms['render/base_color/ssim'] + self.lambda_lpips * terms['render/base_color/lpips'] + \
self.lambda_ssim * terms['render/mra/ssim'] + self.lambda_lpips * terms['render/mra/lpips'])
# KL regularization
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=1,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
dataloader.dataset.with_mesh = True
# inference
gts = []
recons = []
self.models['encoder'].eval()
self.models['decoder'].eval()
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = {k: v[:batch] for k, v in data.items()}
args = recursive_to_device(args, self.device)
z = self.models['encoder'](args['x'])
y = self.models['decoder'](z)
gts.extend(args['mesh'])
recons.extend([MeshWithVoxel(
m.vertices,
m.faces,
[-0.5, -0.5, -0.5],
1 / self.dataset.resolution,
v.coords[:, 1:],
v.feats * 0.5 + 0.5,
torch.Size([*v.shape, *v.spatial_shape]),
layout={
'base_color': slice(0, 3),
'metallic': slice(3, 4),
'roughness': slice(4, 5),
'alpha': slice(5, 6),
}
) for m, v in zip(args['mesh'], y)])
self.models['encoder'].train()
self.models['decoder'].train()
cameras = self._randomize_camera(num_samples)
gt_renders = self._render_batch(gts, **cameras)
pred_renders = self._render_batch(recons, **cameras)
sample_dict = {
'gt_base_color': {'value': gt_renders['base_color'] * 2 - 1, 'type': 'image'},
'pred_base_color': {'value': pred_renders['base_color'] * 2 - 1, 'type': 'image'},
'gt_mra': {'value': torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'},
'pred_mra': {'value': torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'},
}
return sample_dict
from typing import *
import os
import copy
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import utils3d
from easydict import EasyDict as edict
from ..basic import BasicTrainer
from ...modules import sparse as sp
from ...renderers import MeshRenderer
from ...representations import Mesh
from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
from ...utils.loss_utils import l1_loss, ssim, lpips
class ShapeVaeTrainer(BasicTrainer):
"""
Trainer for Shape VAE
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
lambda_subdiv (float): Subdivision loss weight.
lambda_intersected (float): Intersected loss weight.
lambda_vertice (float): Vertice loss weight.
lambda_kl (float): KL loss weight.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
"""
def __init__(
self,
*args,
lambda_subdiv: float = 0.1,
lambda_intersected: float = 0.1,
lambda_vertice: float = 1e-2,
lambda_mask: float = 1,
lambda_depth: float = 10,
lambda_normal: float = 1,
lambda_kl: float = 1e-6,
lambda_ssim: float = 0.2,
lambda_lpips: float = 0.2,
render_resolution: float = 1024,
camera_randomization_config: dict = {
'radius_range': [2, 100],
},
**kwargs
):
super().__init__(*args, **kwargs)
self.lambda_subdiv = lambda_subdiv
self.lambda_intersected = lambda_intersected
self.lambda_mask = lambda_mask
self.lambda_vertice = lambda_vertice
self.lambda_depth = lambda_depth
self.lambda_normal = lambda_normal
self.lambda_kl = lambda_kl
self.lambda_ssim = lambda_ssim
self.lambda_lpips = lambda_lpips
self.camera_randomization_config = camera_randomization_config
self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device)
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = BalancedResumableSampler(
self.dataset,
shuffle=True,
batch_size=self.batch_size_per_gpu,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
def _randomize_camera(self, num_samples: int):
# sample radius and fov
r_min, r_max = self.camera_randomization_config['radius_range']
k_min = 1 / r_max**2
k_max = 1 / r_min**2
ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min
radius = 1 / torch.sqrt(ks)
fov = 2 * torch.arcsin(0.5 / radius)
origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1)
# build camera
extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device))
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
near = [np.random.uniform(r - 1, r) for r in radius.tolist()]
return {
'extrinsics': extrinsics,
'intrinsics': intrinsics,
'near': near,
}
def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
"""
Render a batch of representations.
Args:
reps: The dictionary of lists of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
Returns:
a dict with
mask : [N x 1 x H x W] tensor of rendered masks
normal : [N x 3 x H x W] tensor of rendered normals
depth : [N x 1 x H x W] tensor of rendered depths
"""
ret = {k : [] for k in return_types}
for i, rep in enumerate(reps):
self.renderer.rendering_options['near'] = near[i]
self.renderer.rendering_options['far'] = near[i] + 2
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
for k in out_dict:
ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
for k in ret:
ret[k] = torch.stack(ret[k])
return ret
def training_losses(
self,
vertices: sp.SparseTensor,
intersected: sp.SparseTensor,
mesh: List[Mesh],
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
vertices (SparseTensor): vertices of each active voxel
intersected (SparseTensor): intersected flag of each active voxel
mesh (List[Mesh]): the list of meshes to render
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True)
recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected)
terms = edict(loss = 0.0)
# direct regression
if self.lambda_intersected > 0:
terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float())
terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"]
if self.lambda_vertice > 0:
terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats)
terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"]
# subdivision prediction loss
for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)):
terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float())
terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"]
# rendering loss
cameras = self._randomize_camera(len(mesh))
gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth'])
pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth'])
terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask'])
terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth'])
terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal'])
terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal'])
terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal'])
terms['loss'] = terms['loss'] + \
self.lambda_mask * terms['render/mask'] + \
self.lambda_depth * terms['render/depth'] + \
self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips'])
# KL regularization
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=1,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
gts = []
recons = []
recons2 = []
self.models['encoder'].eval()
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = {k: v[:batch] for k, v in data.items()}
args = recursive_to_device(args, self.device)
z = self.models['encoder'](args['vertices'], args['intersected'])
self.models['decoder'].train()
y = self.models['decoder'](z, args['intersected'])[0]
z.clear_spatial_cache()
self.models['decoder'].eval()
y2 = self.models['decoder'](z)
gts.extend(args['mesh'])
recons.extend(y)
recons2.extend(y2)
self.models['encoder'].train()
self.models['decoder'].train()
cameras = self._randomize_camera(num_samples)
gt_renders = self._render_batch(gts, **cameras, return_types=['normal'])
recons_renders = self._render_batch(recons, **cameras, return_types=['normal'])
recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal'])
sample_dict = {
'gt': {'value': gt_renders['normal'], 'type': 'image'},
'rec': {'value': recons_renders['normal'], 'type': 'image'},
'rec2': {'value': recons2_renders['normal'], 'type': 'image'},
}
return sample_dict
from typing import *
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from ..basic import BasicTrainer
class SparseStructureVaeTrainer(BasicTrainer):
"""
Trainer for Sparse Structure VAE.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss.
lambda_kl (float): KL divergence loss weight.
"""
def __init__(
self,
*args,
loss_type='bce',
lambda_kl=1e-6,
**kwargs
):
super().__init__(*args, **kwargs)
self.loss_type = loss_type
self.lambda_kl = lambda_kl
def training_losses(
self,
ss: torch.Tensor,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
ss: The [N x 1 x H x W x D] tensor of binary sparse structure.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True)
logits = self.training_models['decoder'](z)
terms = edict(loss = 0.0)
if self.loss_type == 'bce':
terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean')
terms["loss"] = terms["loss"] + terms["bce"]
elif self.loss_type == 'l1':
terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean')
terms["loss"] = terms["loss"] + terms["l1"]
elif self.loss_type == 'dice':
logits = F.sigmoid(logits)
terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
terms["loss"] = terms["loss"] + terms["dice"]
else:
raise ValueError(f'Invalid loss type {self.loss_type}')
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
terms["loss"] = terms["loss"] + self.lamda_kl * terms["kl"]
return terms, {}
@torch.no_grad()
def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False):
super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose)
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
gts = []
recons = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
z = self.models['encoder'](args['ss'].float(), sample_posterior=False)
logits = self.models['decoder'](z)
recon = (logits > 0).long()
gts.append(args['ss'])
recons.append(recon)
sample_dict = {
'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'},
'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'},
}
return sample_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment