Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from typing import *
import math
import torch
import numpy as np
from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
import torch.distributed as dist
def recursive_to_device(
data: Any,
device: torch.device,
non_blocking: bool = False,
) -> Any:
"""
Recursively move all tensors in a data structure to a device.
"""
if hasattr(data, "to"):
return data.to(device, non_blocking=non_blocking)
elif isinstance(data, (list, tuple)):
return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
elif isinstance(data, dict):
return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
else:
return data
def load_balanced_group_indices(
load: List[int],
num_groups: int,
equal_size: bool = False,
) -> List[List[int]]:
"""
Split indices into groups with balanced load.
"""
if equal_size:
group_size = len(load) // num_groups
indices = np.argsort(load)[::-1]
groups = [[] for _ in range(num_groups)]
group_load = np.zeros(num_groups)
for idx in indices:
min_group_idx = np.argmin(group_load)
groups[min_group_idx].append(idx)
if equal_size and len(groups[min_group_idx]) == group_size:
group_load[min_group_idx] = float('inf')
else:
group_load[min_group_idx] += load[idx]
return groups
def cycle(data_loader: DataLoader) -> Iterator:
while True:
for data in data_loader:
if isinstance(data_loader.sampler, ResumableSampler):
data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined]
yield data
if isinstance(data_loader.sampler, DistributedSampler):
data_loader.sampler.epoch += 1
if isinstance(data_loader.sampler, ResumableSampler):
data_loader.sampler.epoch += 1
data_loader.sampler.idx = 0
class ResumableSampler(Sampler):
"""
Distributed sampler that is resumable.
Args:
dataset: Dataset used for sampling.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
indices.
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
"""
def __init__(
self,
dataset: Dataset,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
) -> None:
self.dataset = dataset
self.epoch = 0
self.idx = 0
self.drop_last = drop_last
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.rank = dist.get_rank() if dist.is_initialized() else 0
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
self.total_size = self.num_samples * self.world_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank : self.total_size : self.world_size]
# resume from previous state
indices = indices[self.idx:]
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def state_dict(self) -> dict[str, int]:
return {
'epoch': self.epoch,
'idx': self.idx,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.idx = state_dict['idx']
class BalancedResumableSampler(ResumableSampler):
"""
Distributed sampler that is resumable and balances the load among the processes.
Args:
dataset: Dataset used for sampling.
rank (int, optional): Rank of the current process within :attr:`num_replicas`.
By default, :attr:`rank` is retrieved from the current distributed
group.
shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
indices.
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
drop_last (bool, optional): if ``True``, then the sampler will drop the
tail of the data to make it evenly divisible across the number of
replicas. If ``False``, the sampler will add extra indices to make
the data evenly divisible across the replicas. Default: ``False``.
"""
def __init__(
self,
dataset: Dataset,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 1,
) -> None:
assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
super().__init__(dataset, shuffle, seed, drop_last)
self.batch_size = batch_size
self.loads = dataset.loads
def __iter__(self) -> Iterator:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# balance load among processes
num_batches = len(indices) // (self.batch_size * self.world_size)
balanced_indices = []
for i in range(num_batches):
start_idx = i * self.batch_size * self.world_size
end_idx = (i + 1) * self.batch_size * self.world_size
batch_indices = indices[start_idx:end_idx]
batch_loads = [self.loads[idx] for idx in batch_indices]
groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
# resume from previous state
indices = balanced_indices[self.idx:]
return iter(indices)
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(local_rank)
torch.cuda.set_device(local_rank)
dist.init_process_group('nccl', rank=rank, world_size=world_size)
def read_file_dist(path):
"""
Read the binary file distributedly.
File is only read once by the rank 0 process and broadcasted to other processes.
Returns:
data (io.BytesIO): The binary data read from the file.
"""
if dist.is_initialized() and dist.get_world_size() > 1:
# read file
size = torch.LongTensor(1).cuda()
if dist.get_rank() == 0:
with open(path, 'rb') as f:
data = f.read()
data = torch.ByteTensor(
torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
).cuda()
size[0] = data.shape[0]
# broadcast size
dist.broadcast(size, src=0)
if dist.get_rank() != 0:
data = torch.ByteTensor(size[0].item()).cuda()
# broadcast data
dist.broadcast(data, src=0)
# convert to io.BytesIO
data = data.cpu().numpy().tobytes()
data = io.BytesIO(data)
return data
else:
with open(path, 'rb') as f:
data = f.read()
data = io.BytesIO(data)
return data
def unwrap_dist(model):
"""
Unwrap the model from distributed training.
"""
if isinstance(model, DDP):
return model.module
return model
@contextmanager
def master_first():
"""
A context manager that ensures master process executes first.
"""
if not dist.is_initialized():
yield
else:
if dist.get_rank() == 0:
yield
dist.barrier()
else:
dist.barrier()
yield
@contextmanager
def local_master_first():
"""
A context manager that ensures local master process executes first.
"""
if not dist.is_initialized():
yield
else:
if dist.get_rank() % torch.cuda.device_count() == 0:
yield
dist.barrier()
else:
dist.barrier()
yield
\ No newline at end of file
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
class MemoryController:
"""
Base class for memory management during training.
"""
_last_input_size = None
_last_mem_ratio = []
@contextmanager
def record(self):
pass
def update_run_states(self, input_size=None, mem_ratio=None):
if self._last_input_size is None:
self._last_input_size = input_size
elif self._last_input_size!= input_size:
raise ValueError(f'Input size should not change for different ElasticModules.')
self._last_mem_ratio.append(mem_ratio)
@abstractmethod
def get_mem_ratio(self, input_size):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def log(self):
pass
class LinearMemoryController(MemoryController):
"""
A simple controller for memory management during training.
The memory usage is modeled as a linear function of:
- the number of input parameters
- the ratio of memory the model use compared to the maximum usage (with no checkpointing)
memory_usage = k * input_size * mem_ratio + b
The controller keeps track of the memory usage and gives the
expected memory ratio to keep the memory usage under a target
"""
def __init__(
self,
buffer_size=1000,
update_every=500,
target_ratio=0.8,
available_memory=None,
max_mem_ratio_start=0.1,
params=None,
device=None
):
self.buffer_size = buffer_size
self.update_every = update_every
self.target_ratio = target_ratio
self.device = device or torch.cuda.current_device()
self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
self._memory = np.zeros(buffer_size, dtype=np.float32)
self._input_size = np.zeros(buffer_size, dtype=np.float32)
self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
self._buffer_ptr = 0
self._buffer_length = 0
self._params = tuple(params) if params is not None else (0.0, 0.0)
self._max_mem_ratio = max_mem_ratio_start
self.step = 0
def __repr__(self):
return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
def _add_sample(self, memory, input_size, mem_ratio):
self._memory[self._buffer_ptr] = memory
self._input_size[self._buffer_ptr] = input_size
self._mem_ratio[self._buffer_ptr] = mem_ratio
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
@contextmanager
def record(self):
torch.cuda.reset_peak_memory_stats(self.device)
self._last_input_size = None
self._last_mem_ratio = []
yield
self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
self.step += 1
if self.step % self.update_every == 0:
self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
self._fit_params()
def _fit_params(self):
memory_usage = self._memory[:self._buffer_length]
input_size = self._input_size[:self._buffer_length]
mem_ratio = self._mem_ratio[:self._buffer_length]
x = input_size * mem_ratio
y = memory_usage
k, b = np.polyfit(x, y, 1)
self._params = (k, b)
# self._visualize()
def _visualize(self):
import matplotlib.pyplot as plt
memory_usage = self._memory[:self._buffer_length]
input_size = self._input_size[:self._buffer_length]
mem_ratio = self._mem_ratio[:self._buffer_length]
k, b = self._params
plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
x = np.array([0.0, 20000.0])
plt.plot(x, k * x + b, c='r')
plt.savefig(f'linear_memory_controller_{self.step}.png')
plt.cla()
def get_mem_ratio(self, input_size):
k, b = self._params
if k == 0: return np.random.rand() * self._max_mem_ratio
pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
return min(self._max_mem_ratio, max(0.0, pred))
def state_dict(self):
return {
'params': self._params,
}
def load_state_dict(self, state_dict):
self._params = tuple(state_dict['params'])
def log(self):
return {
'params/k': self._params[0],
'params/b': self._params[1],
'memory': self._last_memory,
'input_size': self._last_input_size,
'mem_ratio': self._last_mem_ratio,
}
class ElasticModule(nn.Module):
"""
Module for training with elastic memory management.
"""
def __init__(self):
super().__init__()
self._memory_controller: MemoryController = None
@abstractmethod
def _get_input_size(self, *args, **kwargs) -> int:
"""
Get the size of the input data.
Returns:
int: The size of the input data.
"""
pass
@abstractmethod
def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
"""
Forward with a given memory ratio.
"""
pass
def register_memory_controller(self, memory_controller: MemoryController):
self._memory_controller = memory_controller
def forward(self, *args, **kwargs):
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
_, ret = self._forward_with_mem_ratio(*args, **kwargs)
else:
input_size = self._get_input_size(*args, **kwargs)
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
self._memory_controller.update_run_states(input_size, mem_ratio)
return ret
class ElasticModuleMixin:
"""
Mixin for training with elastic memory management.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._memory_controller: MemoryController = None
@abstractmethod
def _get_input_size(self, *args, **kwargs) -> int:
"""
Get the size of the input data.
Returns:
int: The size of the input data.
"""
pass
@abstractmethod
@contextmanager
def with_mem_ratio(self, mem_ratio=1.0) -> float:
"""
Context manager for training with a reduced memory ratio compared to the full memory usage.
Returns:
float: The exact memory ratio used during the forward pass.
"""
pass
def register_memory_controller(self, memory_controller: MemoryController):
self._memory_controller = memory_controller
def forward(self, *args, **kwargs):
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
ret = super().forward(*args, **kwargs)
else:
input_size = self._get_input_size(*args, **kwargs)
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
ret = super().forward(*args, **kwargs)
self._memory_controller.update_run_states(input_size, exact_mem_ratio)
return ret
import re
import numpy as np
import cv2
import torch
import contextlib
# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
"""
Merge two dictionaries.
"""
assert isinstance(dicta, dict), 'input must be a dictionary'
assert isinstance(dictb, dict), 'input must be a dictionary'
dict_ = {}
all_keys = set(dicta.keys()).union(set(dictb.keys()))
for key in all_keys:
if key in dicta.keys() and key in dictb.keys():
if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
else:
raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
elif key in dicta.keys():
dict_[key] = dicta[key]
else:
dict_[key] = dictb[key]
return dict_
def dict_merge(dicta, dictb):
"""
Merge two dictionaries.
"""
return _dict_merge(dicta, dictb, prefix='')
def dict_foreach(dic, func, special_func={}):
"""
Recursively apply a function to all non-dictionary leaf values in a dictionary.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
dic[key] = dict_foreach(dic[key], func)
else:
if key in special_func.keys():
dic[key] = special_func[key](dic[key])
else:
dic[key] = func(dic[key])
return dic
def dict_reduce(dicts, func, special_func={}):
"""
Reduce a list of dictionaries. Leaf values must be scalars.
"""
assert isinstance(dicts, list), 'input must be a list of dictionaries'
assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
all_keys = set([key for dict_ in dicts for key in dict_.keys()])
reduced_dict = {}
for key in all_keys:
vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
if isinstance(vlist[0], dict):
reduced_dict[key] = dict_reduce(vlist, func, special_func)
else:
if key in special_func.keys():
reduced_dict[key] = special_func[key](vlist)
else:
reduced_dict[key] = func(vlist)
return reduced_dict
def dict_any(dic, func):
"""
Recursively apply a function to all non-dictionary leaf values in a dictionary.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if dict_any(dic[key], func):
return True
else:
if func(dic[key]):
return True
return False
def dict_all(dic, func):
"""
Recursively apply a function to all non-dictionary leaf values in a dictionary.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if not dict_all(dic[key], func):
return False
else:
if not func(dic[key]):
return False
return True
def dict_flatten(dic, sep='.'):
"""
Flatten a nested dictionary into a dictionary with no nested dictionaries.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
flat_dict = {}
for key in dic.keys():
if isinstance(dic[key], dict):
sub_dict = dict_flatten(dic[key], sep=sep)
for sub_key in sub_dict.keys():
flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
else:
flat_dict[key] = dic[key]
return flat_dict
# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
with contextlib.ExitStack() as stack:
for ctx in contexts:
stack.enter_context(ctx())
yield
# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
num_images = len(images)
if nrow is None and ncol is None:
if aspect_ratio is not None:
nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
else:
nrow = int(np.sqrt(num_images))
ncol = (num_images + nrow - 1) // nrow
elif nrow is None and ncol is not None:
nrow = (num_images + ncol - 1) // ncol
elif nrow is not None and ncol is None:
ncol = (num_images + nrow - 1) // nrow
else:
assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
if images[0].ndim == 2:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
else:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
for i, img in enumerate(images):
row = i // ncol
col = i % ncol
grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
return grid
def notes_on_image(img, notes=None):
img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if notes is not None:
img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
"""
Draw text on an image of the given resolution. The text is automatically wrapped
and scaled so that it fits completely within the image while preserving any explicit
line breaks and original spacing. Horizontal and vertical alignment can be controlled
via flags.
Parameters:
text (str): The input text. Newline characters and spacing are preserved.
resolution (tuple): The image resolution as (width, height).
max_size (float): The maximum font size.
h_align (str): Horizontal alignment. Options: "left", "center", "right".
v_align (str): Vertical alignment. Options: "top", "center", "bottom".
Returns:
numpy.ndarray: The resulting image (BGR format) with the text drawn.
"""
width, height = resolution
# Create a white background image
img = np.full((height, width, 3), 255, dtype=np.uint8)
# Set margins and compute available drawing area
margin = 10
avail_width = width - 2 * margin
avail_height = height - 2 * margin
# Choose OpenCV font and text thickness
font = cv2.FONT_HERSHEY_SIMPLEX
thickness = 1
# Ratio for additional spacing between lines (relative to the height of "A")
line_spacing_ratio = 0.5
def wrap_line(line, max_width, font, thickness, scale):
"""
Wrap a single line of text into multiple lines such that each line's
width (measured at the given scale) does not exceed max_width.
This function preserves the original spacing by splitting the line into tokens
(words and whitespace) using a regular expression.
Parameters:
line (str): The input text line.
max_width (int): Maximum allowed width in pixels.
font (int): OpenCV font identifier.
thickness (int): Text thickness.
scale (float): The current font scale.
Returns:
List[str]: A list of wrapped lines.
"""
# Split the line into tokens (words and whitespace), preserving spacing
tokens = re.split(r'(\s+)', line)
if not tokens:
return ['']
wrapped_lines = []
current_line = ""
for token in tokens:
candidate = current_line + token
candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
if candidate_width <= max_width:
current_line = candidate
else:
# If current_line is empty, the token itself is too wide;
# break the token character by character.
if current_line == "":
sub_token = ""
for char in token:
candidate_char = sub_token + char
if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
sub_token = candidate_char
else:
if sub_token:
wrapped_lines.append(sub_token)
sub_token = char
current_line = sub_token
else:
wrapped_lines.append(current_line)
current_line = token
if current_line:
wrapped_lines.append(current_line)
return wrapped_lines
def compute_text_block(scale):
"""
Wrap the entire text (splitting at explicit newline characters) using the
provided scale, and then compute the overall width and height of the text block.
Returns:
wrapped_lines (List[str]): The list of wrapped lines.
block_width (int): Maximum width among the wrapped lines.
block_height (int): Total height of the text block including spacing.
sizes (List[tuple]): A list of (width, height) for each wrapped line.
spacing (int): The spacing between lines (computed from the scaled "A" height).
"""
# Split text by explicit newlines
input_lines = text.splitlines() if text else ['']
wrapped_lines = []
for line in input_lines:
wrapped = wrap_line(line, avail_width, font, thickness, scale)
wrapped_lines.extend(wrapped)
sizes = []
for line in wrapped_lines:
(text_size, _) = cv2.getTextSize(line, font, scale, thickness)
sizes.append(text_size) # (width, height)
block_width = max((w for w, h in sizes), default=0)
# Use the height of "A" (at the current scale) to compute line spacing
base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
spacing = int(line_spacing_ratio * base_height)
block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0
return wrapped_lines, block_width, block_height, sizes, spacing
# Use binary search to find the maximum scale that allows the text block to fit
lo = 0.001
hi = max_size
eps = 0.001 # convergence threshold
best_scale = lo
best_result = None
while hi - lo > eps:
mid = (lo + hi) / 2
wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
# Ensure that both width and height constraints are met
if block_width <= avail_width and block_height <= avail_height:
best_scale = mid
best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
lo = mid # try a larger scale
else:
hi = mid # reduce the scale
if best_result is None:
best_scale = 0.5
best_result = compute_text_block(best_scale)
wrapped_lines, block_width, block_height, sizes, spacing = best_result
# Compute starting y-coordinate based on vertical alignment flag
if v_align == "top":
y_top = margin
elif v_align == "center":
y_top = margin + (avail_height - block_height) // 2
elif v_align == "bottom":
y_top = margin + (avail_height - block_height)
else:
y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag
# For cv2.putText, the y coordinate represents the text baseline;
# so for the first line add its height.
y = y_top + (sizes[0][1] if sizes else 0)
# Draw each line with horizontal alignment based on the flag
for i, line in enumerate(wrapped_lines):
line_width, line_height = sizes[i]
if h_align == "left":
x = margin
elif h_align == "center":
x = margin + (avail_width - line_width) // 2
elif h_align == "right":
x = margin + (avail_width - line_width)
else:
x = margin # default to left if invalid flag
cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
y += line_height + spacing
return img
def save_image_with_notes(img, path, notes=None):
"""
Save an image with notes.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy().transpose(1, 2, 0)
if img.dtype == np.float32 or img.dtype == np.float64:
img = np.clip(img * 255, 0, 255).astype(np.uint8)
img = notes_on_image(img, notes)
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# debug utils
def atol(x, y):
"""
Absolute tolerance.
"""
return torch.abs(x - y)
def rtol(x, y):
"""
Relative tolerance.
"""
return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
# print utils
def indent(s, n=4):
"""
Indent a string.
"""
lines = s.split('\n')
for i in range(1, len(lines)):
lines[i] = ' ' * n + lines[i]
return '\n'.join(lines)
from typing import *
import torch
import numpy as np
import torch.utils
class AdaptiveGradClipper:
"""
Adaptive gradient clipping for training.
"""
def __init__(
self,
max_norm=None,
clip_percentile=95.0,
buffer_size=1000,
):
self.max_norm = max_norm
self.clip_percentile = clip_percentile
self.buffer_size = buffer_size
self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
self._max_norm = max_norm
self._buffer_ptr = 0
self._buffer_length = 0
def __repr__(self):
return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
def state_dict(self):
return {
'grad_norm': self._grad_norm,
'max_norm': self._max_norm,
'buffer_ptr': self._buffer_ptr,
'buffer_length': self._buffer_length,
}
def load_state_dict(self, state_dict):
self._grad_norm = state_dict['grad_norm']
self._max_norm = state_dict['max_norm']
self._buffer_ptr = state_dict['buffer_ptr']
self._buffer_length = state_dict['buffer_length']
def log(self):
return {
'max_norm': self._max_norm,
}
def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
"""Clip the gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
max_norm = self._max_norm if self._max_norm is not None else float('inf')
grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
if torch.isfinite(grad_norm):
self._grad_norm[self._buffer_ptr] = grad_norm
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
if self._buffer_length == self.buffer_size:
self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
return grad_norm
\ No newline at end of file
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
from lpips import LPIPS
def smooth_l1_loss(pred, target, beta=1.0):
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
return loss.mean()
def l1_loss(network_output, gt):
return torch.abs((network_output - gt)).mean()
def l2_loss(network_output, gt):
return ((network_output - gt) ** 2).mean()
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def psnr(img1, img2, max_val=1.0):
mse = F.mse_loss(img1, img2)
return 20 * torch.log10(max_val / torch.sqrt(mse))
def ssim(img1, img2, window_size=11, size_average=True):
channel = img1.size(-3)
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
loss_fn_vgg = None
def lpips(img1, img2, value_range=(0, 1)):
global loss_fn_vgg
if loss_fn_vgg is None:
loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
# normalize to [-1, 1]
img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
return loss_fn_vgg(img1, img2).mean()
def normal_angle(pred, gt):
pred = pred * 2.0 - 1.0
gt = gt * 2.0 - 1.0
norms = pred.norm(dim=-1) * gt.norm(dim=-1)
cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
if ang.isnan():
return -1
return ang
from typing import Tuple, Dict
import numpy as np
from trimesh import grouping, util, remesh
import struct
import re
from plyfile import PlyData, PlyElement
def read_ply(filename):
"""
Read a PLY file and return vertices, triangle faces, and quad faces.
Args:
filename (str): The file path to read from.
Returns:
vertices (np.ndarray): Array of shape [N, 3] containing vertex positions.
tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none).
quads (np.ndarray): Array of shape [K, 4] containing quad face indices (empty if none).
"""
with open(filename, 'rb') as f:
# Read the header until 'end_header' is encountered
header_bytes = b""
while True:
line = f.readline()
if not line:
raise ValueError("PLY header not found")
header_bytes += line
if b"end_header" in line:
break
header = header_bytes.decode('utf-8')
# Determine if the file is in ASCII or binary format
is_ascii = "ascii" in header
# Extract the number of vertices and faces from the header using regex
vertex_match = re.search(r'element vertex (\d+)', header)
if vertex_match:
num_vertices = int(vertex_match.group(1))
else:
raise ValueError("Vertex count not found in header")
face_match = re.search(r'element face (\d+)', header)
if face_match:
num_faces = int(face_match.group(1))
else:
raise ValueError("Face count not found in header")
vertices = []
tris = []
quads = []
if is_ascii:
# For ASCII format, read each line of vertex data (each line contains 3 floats)
for _ in range(num_vertices):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
# Read face data, where the first number indicates the number of vertices for the face
for _ in range(num_faces):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
count = int(parts[0])
indices = list(map(int, parts[1:]))
if count == 3:
tris.append(indices)
elif count == 4:
quads.append(indices)
else:
# Skip faces with other numbers of vertices (can be extended as needed)
pass
else:
# For binary format: read directly from the binary stream
# Each vertex consists of 3 floats (12 bytes per vertex)
for _ in range(num_vertices):
data = f.read(12)
if len(data) < 12:
raise ValueError("Insufficient vertex data")
v = struct.unpack('<fff', data)
vertices.append(v)
# Read face data from the binary stream
for _ in range(num_faces):
# First, read 1 byte indicating the number of vertices in the face
count_data = f.read(1)
if len(count_data) < 1:
raise ValueError("Failed to read face vertex count")
count = struct.unpack('<B', count_data)[0]
if count == 3:
data = f.read(12) # 3 * 4 bytes
if len(data) < 12:
raise ValueError("Insufficient data for triangle face")
indices = struct.unpack('<3i', data)
tris.append(indices)
elif count == 4:
data = f.read(16) # 4 * 4 bytes
if len(data) < 16:
raise ValueError("Insufficient data for quad face")
indices = struct.unpack('<4i', data)
quads.append(indices)
else:
# For faces with a different number of vertices, read count*4 bytes
data = f.read(count * 4)
# Skip or extend processing as needed
raise ValueError(f"Unsupported face with {count} vertices")
# Convert lists to torch.Tensor
vertices = np.array(vertices, dtype=np.float32)
tris = np.array(tris, dtype=np.int32) if len(tris) > 0 else np.empty((0, 3), dtype=np.int32)
quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32)
return vertices, tris, quads
def write_ply(
filename: str,
vertices: np.ndarray,
tris: np.ndarray,
quads: np.ndarray,
vertex_colors: np.ndarray = None,
ascii: bool = False
):
"""
Write a mesh to a PLY file, with the option to save in ASCII or binary format,
and optional per-vertex colors.
Args:
filename (str): The filename to write to.
vertices (np.ndarray): [N, 3] The vertex positions.
tris (np.ndarray): [M, 3] The triangle indices.
quads (np.ndarray): [K, 4] The quad indices.
vertex_colors (np.ndarray, optional): [N, 3] or [N, 4] UInt8 colors for each vertex (RGB or RGBA).
ascii (bool): If True, write in ASCII format; otherwise binary little-endian.
"""
import struct
num_vertices = len(vertices)
num_faces = len(tris) + len(quads)
# Build header
header_lines = [
"ply",
f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}",
f"element vertex {num_vertices}",
"property float x",
"property float y",
"property float z",
]
# Add vertex color properties if provided
has_color = vertex_colors is not None
if has_color:
# Expect uint8 values 0-255
header_lines += [
"property uchar red",
"property uchar green",
"property uchar blue",
]
# Include alpha if RGBA
if vertex_colors.shape[1] == 4:
header_lines.append("property uchar alpha")
header_lines += [
f"element face {num_faces}",
"property list uchar int vertex_index",
"end_header",
""
]
header = "\n".join(header_lines)
mode = 'w' if ascii else 'wb'
with open(filename, mode) as f:
# Write header
if ascii:
f.write(header)
else:
f.write(header.encode('utf-8'))
# Write vertex data
for i, v in enumerate(vertices):
if ascii:
line = f"{v[0]} {v[1]} {v[2]}"
if has_color:
col = vertex_colors[i]
line += ' ' + ' '.join(str(int(c)) for c in col)
f.write(line + '\n')
else:
# pack xyz as floats
f.write(struct.pack('<fff', *v))
if has_color:
col = vertex_colors[i]
# pack as uchar
if col.shape[0] == 3:
f.write(struct.pack('<BBB', *col))
else:
f.write(struct.pack('<BBBB', *col))
# Write face data
if ascii:
for tri in tris:
f.write(f"3 {tri[0]} {tri[1]} {tri[2]}\n")
for quad in quads:
f.write(f"4 {quad[0]} {quad[1]} {quad[2]} {quad[3]}\n")
else:
for tri in tris:
f.write(struct.pack('<B3i', 3, *tri))
for quad in quads:
f.write(struct.pack('<B4i', 4, *quad))
def write_pbr_ply(
filename: str,
vertices: np.ndarray,
faces: np.ndarray,
base_color: np.ndarray,
metallic: np.ndarray,
roughness: np.ndarray,
alpha: np.ndarray,
ascii: bool = False
):
"""
Write a mesh to a PLY file, with the option to save in ASCII or binary format,
and optional per-vertex colors.
Args:
filename (str): The filename to write to.
vertices (np.ndarray): [N, 3] The vertex positions.
faces (np.ndarray): [M, 3] The triangle indices.
base_color (np.ndarray): [N, 3] UInt8 colors for each vertex (RGB).
metallic (np.ndarray): [N] UInt8 values for metallicness.
roughness (np.ndarray): [N] UInt8 values for roughness.
alpha (np.ndarray): [N] UInt8 values for alpha.
ascii (bool): If True, write in ASCII format; otherwise binary little-endian.
"""
vertex_dtype = [
('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
('metallic', 'u1'), ('roughness', 'u1'), ('alpha', 'u1')
]
vertex_data = np.empty(len(vertices), dtype=vertex_dtype)
vertex_data['x'] = vertices[:, 0]
vertex_data['y'] = vertices[:, 1]
vertex_data['z'] = vertices[:, 2]
vertex_data['red'] = base_color[:, 0]
vertex_data['green'] = base_color[:, 1]
vertex_data['blue'] = base_color[:, 2]
vertex_data['metallic'] = metallic
vertex_data['roughness'] = roughness
vertex_data['alpha'] = alpha
face_dtype = [
('vertex_indices', 'i4', (3,))
]
face_data = np.empty(len(faces), dtype=face_dtype)
face_data['vertex_indices'] = faces
ply_data = PlyData([
PlyElement.describe(vertex_data,'vertex'),
PlyElement.describe(face_data, 'face'),
], text=ascii)
ply_data.write(filename)
"""
Pipeline logging utility for Trellis 2.
Writes to both /tmp/trellis2_pipeline.log (full DEBUG) and stdout (INFO+).
Call reset_log() at the start of each new generation run.
"""
import sys
import time
import logging
import traceback
from datetime import datetime
from typing import Optional
import torch
import numpy as np
LOG_PATH = "/tmp/trellis2_pipeline.log"
_logger: Optional[logging.Logger] = None
_run_start: float = 0.0
_debug_enabled: bool = False
def _make_logger() -> logging.Logger:
log = logging.getLogger("trellis2_pipeline")
log.setLevel(logging.DEBUG)
log.handlers.clear()
log.propagate = False
fmt = logging.Formatter(
"%(asctime)s.%(msecs)03d %(levelname)-5s %(message)s",
datefmt="%H:%M:%S",
)
fh = logging.FileHandler(LOG_PATH, mode="w", encoding="utf-8")
fh.setLevel(logging.DEBUG)
fh.setFormatter(fmt)
log.addHandler(fh)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG if _debug_enabled else logging.INFO)
ch.setFormatter(fmt)
log.addHandler(ch)
return log
def reset_log(label: str = "") -> None:
"""Call at the very start of every new generation request."""
global _logger, _run_start
_logger = _make_logger()
_run_start = time.perf_counter()
_logger.info("=" * 80)
_logger.info(f"NEW PIPELINE RUN {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} {label}")
_logger.info("=" * 80)
def get_logger() -> logging.Logger:
global _logger
if _logger is None:
_logger = _make_logger()
return _logger
def set_debug(enabled: bool) -> None:
"""Enable or disable DEBUG-level output to stdout."""
global _debug_enabled, _logger
_debug_enabled = enabled
# Update any already-created logger's stdout handler
if _logger is not None:
for handler in _logger.handlers:
if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler):
handler.setLevel(logging.DEBUG if enabled else logging.INFO)
def elapsed() -> str:
return f"+{time.perf_counter() - _run_start:.2f}s"
# ── Tensor helpers ────────────────────────────────────────────────────────────
def _ts(t) -> str:
"""One-line tensor summary."""
if t is None:
return "None"
if not isinstance(t, torch.Tensor):
return f"<{type(t).__name__}>"
try:
f = t.detach().float()
has_nan = torch.isnan(f).any().item()
has_inf = torch.isinf(f).any().item()
if has_nan or has_inf:
flags = ("NaN " if has_nan else "") + ("inf" if has_inf else "")
return f"shape={list(t.shape)} dtype={t.dtype}{flags.strip()}"
mn, mx = f.min().item(), f.max().item()
mu = f.mean().item()
return (f"shape={list(t.shape)} dtype={t.dtype} "
f"min={mn:.4g} max={mx:.4g} mean={mu:.4g}")
except Exception as e:
return f"shape={list(t.shape)} dtype={t.dtype} [stats error: {e}]"
def log_tensor(t, name: str, level: str = "info") -> None:
getattr(get_logger(), level)(f" {elapsed()} [{name}] {_ts(t)}")
def log_mesh(vertices, faces, tag: str = "mesh") -> None:
L = get_logger()
prefix = f" {elapsed()} [MESH:{tag}]"
if vertices is None or faces is None:
L.warning(f"{prefix} vertices or faces is None!")
return
try:
v = vertices.detach().float() if isinstance(vertices, torch.Tensor) else torch.tensor(vertices, dtype=torch.float32)
f = faces.detach() if isinstance(faces, torch.Tensor) else torch.tensor(faces)
has_nan = torch.isnan(v).any().item()
has_inf = torch.isinf(v).any().item()
ok = "✅" if not has_nan and not has_inf else "❌"
L.info(f"{prefix} {ok} "
f"vertices={list(v.shape)} faces={list(f.shape)} "
f"pos=[{v.min().item():.4g}, {v.max().item():.4g}] "
f"NaN={has_nan} inf={has_inf}")
if f.numel() > 0:
idx_min = int(f.min().item())
idx_max = int(f.max().item())
n_verts = v.shape[0]
valid = (idx_min >= 0) and (idx_max < n_verts)
flag = "✅" if valid else "❌ OUT-OF-BOUNDS"
L.info(f"{prefix} face-idx range=[{idx_min}, {idx_max}] "
f"num_vertices={n_verts} {flag}")
if not valid:
L.error(f"{prefix} ⚠ INVALID FACE INDICES — expect corruption downstream!")
if v.shape[0] >= 3:
L.debug(f"{prefix} first 3 verts: {v[:3].tolist()}")
if f.shape[0] >= 3:
L.debug(f"{prefix} first 3 faces: {f[:3].tolist()}")
except Exception as e:
L.error(f"{prefix} exception: {e}\n{traceback.format_exc()}")
def log_uv(uv, tag: str = "uv") -> None:
L = get_logger()
prefix = f" {elapsed()} [UV:{tag}]"
if uv is None:
L.warning(f"{prefix} None!")
return
try:
t = uv.detach().float() if isinstance(uv, torch.Tensor) else torch.tensor(uv, dtype=torch.float32)
has_nan = torch.isnan(t).any().item()
has_inf = torch.isinf(t).any().item()
ok = "✅" if not has_nan and not has_inf else "❌"
u_range = f"[{t[:, 0].min().item():.4g}, {t[:, 0].max().item():.4g}]"
v_range = f"[{t[:, 1].min().item():.4g}, {t[:, 1].max().item():.4g}]"
n_zero = (t.abs().sum(dim=-1) < 1e-8).sum().item()
pct_zero = 100.0 * n_zero / max(1, t.shape[0])
L.info(f"{prefix} {ok} shape={list(t.shape)} "
f"U={u_range} V={v_range} "
f"zeros={n_zero}/{t.shape[0]} ({pct_zero:.1f}%) "
f"NaN={has_nan} inf={has_inf}")
if pct_zero > 50:
L.error(f"{prefix} ⚠ >50% UV coordinates are zero — likely UV gen failure!")
if t.shape[0] >= 3:
L.debug(f"{prefix} first 5 UVs: {t[:5].tolist()}")
except Exception as e:
L.error(f"{prefix} exception: {e}")
def log_sparse(sp_tensor, tag: str = "sparse") -> None:
"""Log a SparseTensor or VarLenTensor."""
L = get_logger()
prefix = f" {elapsed()} [SPARSE:{tag}]"
try:
feats = sp_tensor.feats if hasattr(sp_tensor, "feats") else None
coords = sp_tensor.coords if hasattr(sp_tensor, "coords") else None
if feats is not None:
L.info(f"{prefix} feats: {_ts(feats)}")
if coords is not None:
L.info(f"{prefix} coords: {_ts(coords)} "
f"max={coords.max(dim=0).values.tolist() if coords.numel() > 0 else 'N/A'}")
except Exception as e:
L.error(f"{prefix} exception: {e}")
def section(title: str) -> None:
L = get_logger()
L.info("")
L.info(f"{'─'*70}")
L.info(f" {elapsed()}{title}")
L.info(f"{'─'*70}")
def check_tensor(t, name: str, expect_finite: bool = True) -> bool:
"""Returns True if tensor passes checks. Logs ERROR if not."""
if t is None:
get_logger().warning(f" {elapsed()} [{name}] is None")
return False
if not isinstance(t, torch.Tensor):
return True
f = t.detach().float()
has_nan = torch.isnan(f).any().item()
has_inf = torch.isinf(f).any().item()
if expect_finite and (has_nan or has_inf):
get_logger().error(f" {elapsed()} [{name}] ❌ CORRUPT — NaN={has_nan} inf={has_inf} {_ts(t)}")
return False
return True
import numpy as np
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
u, v = hammersley_sequence(2, n, num_samples)
u += offset[0] / num_samples
v += offset[1]
if remap:
u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
theta = np.arccos(1 - 2 * u) - np.pi / 2
phi = v * 2 * np.pi
return [phi, theta]
\ No newline at end of file
import torch
import numpy as np
from tqdm import tqdm
import utils3d
from PIL import Image
from ..renderers import MeshRenderer, VoxelRenderer, PbrMeshRenderer
from ..representations import Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel
from .random_utils import sphere_hammersley_sequence
def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
is_list = isinstance(yaws, list)
if not is_list:
yaws = [yaws]
pitchs = [pitchs]
if not isinstance(rs, list):
rs = [rs] * len(yaws)
if not isinstance(fovs, list):
fovs = [fovs] * len(yaws)
extrinsics = []
intrinsics = []
for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
yaw = torch.tensor(float(yaw)).cuda()
pitch = torch.tensor(float(pitch)).cuda()
orig = torch.tensor([
torch.sin(yaw) * torch.cos(pitch),
torch.cos(yaw) * torch.cos(pitch),
torch.sin(pitch),
]).cuda() * r
extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
extrinsics.append(extr)
intrinsics.append(intr)
if not is_list:
extrinsics = extrinsics[0]
intrinsics = intrinsics[0]
return extrinsics, intrinsics
def _safe_ssaa(sample, requested_ssaa, resolution, vram_limit_gb=14.0):
"""
Cap ssaa so the estimated peak VRAM stays under vram_limit_gb.
Rough model: raster buffers at (resolution*ssaa)^2, 3 envmaps, 8 peel layers.
Each peel layer: ~160 MB transient (xyz + img + rast).
Constant mesh overhead: ~400 MB.
"""
num_faces = 0
if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)):
num_faces = sample.faces.shape[0] if hasattr(sample, 'faces') else 0
for ssaa in [requested_ssaa, requested_ssaa - 1, 1]:
if ssaa < 1:
ssaa = 1
pixels = (resolution * ssaa) ** 2
# ~160 MB per peel layer (3 envmaps * shaded + rast + xyz + img)
peel_layers = 8
est_mb = (pixels * 4 * 4 * (3 + peel_layers) / 1e6) + 400
if est_mb < vram_limit_gb * 1024:
return ssaa
return 1
def get_renderer(sample, **kwargs):
if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)):
renderer = PbrMeshRenderer()
resolution = kwargs.get('resolution', 512)
requested_ssaa = kwargs.get('ssaa', 1)
ssaa = _safe_ssaa(sample, requested_ssaa, resolution)
if ssaa != requested_ssaa:
import logging
logging.getLogger(__name__).warning(
f"[render_utils] ssaa capped {requested_ssaa}{ssaa} to stay under VRAM limit"
)
renderer.rendering_options.resolution = resolution
renderer.rendering_options.near = kwargs.get('near', 1)
renderer.rendering_options.far = kwargs.get('far', 100)
renderer.rendering_options.ssaa = ssaa
renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8)
elif isinstance(sample, Mesh):
renderer = MeshRenderer()
renderer.rendering_options.resolution = kwargs.get('resolution', 512)
renderer.rendering_options.near = kwargs.get('near', 1)
renderer.rendering_options.far = kwargs.get('far', 100)
renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None)
elif isinstance(sample, Voxel):
renderer = VoxelRenderer()
renderer.rendering_options.resolution = kwargs.get('resolution', 512)
renderer.rendering_options.near = kwargs.get('near', 0.1)
renderer.rendering_options.far = kwargs.get('far', 10.0)
renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
else:
raise ValueError(f'Unsupported sample type: {type(sample)}')
return renderer
def render_frames(sample, extrinsics, intrinsics, options={}, verbose=True, **kwargs):
renderer = get_renderer(sample, **options)
# Free stale GPU allocations from the generation phase before rendering starts.
# On ROCm, driver-level OOM causes a display freeze rather than a Python exception,
# so we clear proactively rather than waiting for the allocator to evict.
torch.cuda.empty_cache()
rets = {}
for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), total=len(extrinsics), desc='Rendering', disable=not verbose):
res = renderer.render(sample, extr, intr, **kwargs)
for k, v in res.items():
if k not in rets: rets[k] = []
if v.dim() == 2: v = v[None].repeat(3, 1, 1)
rets[k].append(np.clip(v.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
return rets
def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=120, r=2, fov=40, **kwargs):
yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + np.pi/2
pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))
yaws = yaws.tolist()
pitch = pitch.tolist()
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
def render_multiview(sample, resolution=512, nviews=30):
r = 2
fov = 40
cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
yaws = [cam[0] for cam in cams]
pitchs = [cam[1] for cam in cams]
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
return res['color'], extrinsics, intrinsics
def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, nviews=4, **kwargs):
yaw = np.linspace(0, 2 * np.pi, nviews, endpoint=False)
yaw_offset = offset[0]
yaw = [y + yaw_offset for y in yaw]
pitch = [offset[1] for _ in range(nviews)]
extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
def make_pbr_vis_frames(result, resolution=1024):
num_frames = len(result['shaded'])
frames = []
for i in range(num_frames):
shaded = Image.fromarray(result['shaded'][i])
normal = Image.fromarray(result['normal'][i])
base_color = Image.fromarray(result['base_color'][i])
metallic = Image.fromarray(result['metallic'][i])
roughness = Image.fromarray(result['roughness'][i])
alpha = Image.fromarray(result['alpha'][i])
shaded = shaded.resize((resolution, resolution))
normal = normal.resize((resolution, resolution))
base_color = base_color.resize((resolution//2, resolution//2))
metallic = metallic.resize((resolution//2, resolution//2))
roughness = roughness.resize((resolution//2, resolution//2))
alpha = alpha.resize((resolution//2, resolution//2))
row1 = np.concatenate([shaded, normal], axis=1)
row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1)
frame = np.concatenate([row1, row2], axis=0)
frames.append(frame)
return frames
from typing import *
import numpy as np
import torch
from ..modules import sparse as sp
from ..representations import Voxel
from .render_utils import render_video
def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor:
"""
Apply PCA to the features and return the first three principal components.
"""
feats = feats.detach()
u, s, v = torch.svd(feats)
color = u[:, channels]
color = (color - color.min(dim=0, keepdim=True)[0]) / (color.max(dim=0, keepdim=True)[0] - color.min(dim=0, keepdim=True)[0])
return color
def vis_sparse_tensor(
x: sp.SparseTensor,
num_frames: int = 300,
):
assert x.shape[0] == 1, "Only support batch size 1"
assert x.coords.shape[1] == 4, "Only support 3D coordinates"
coords = x.coords.cuda().detach()[:, 1:]
feats = x.feats.cuda().detach()
color = pca_color(feats)
resolution = max(list(x.spatial_shape))
resolution = int(2**np.ceil(np.log2(resolution)))
rep = Voxel(
origin=[-0.5, -0.5, -0.5],
voxel_size=1/resolution,
coords=coords,
attrs=color,
layout={
'color': slice(0, 3),
}
)
return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']
icon.png

55.4 KB

# 模型唯一标识
modelCode=15305
# 模型名称
modelName=TRELLIS.2
# 模型描述
modelDescription=TRELLIS.2是一款最先进的大型3D生成模型(40亿参数),专为高保真图像到3D生成而设计。
# 运行过程
processType=推理
# 算法类别
appCategory=3D生成
# 框架类型
frameType=pytorch
# 加速卡类型
accelerateType=BW1000
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment