Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json
import glob
"""=============== BLENDER ==============="""
IMPORT_FUNCTIONS: Dict[str, Callable] = {
"obj": bpy.ops.import_scene.obj,
"glb": bpy.ops.import_scene.gltf,
"gltf": bpy.ops.import_scene.gltf,
"usd": bpy.ops.import_scene.usd,
"fbx": bpy.ops.import_scene.fbx,
"stl": bpy.ops.import_mesh.stl,
"usda": bpy.ops.import_scene.usda,
"dae": bpy.ops.wm.collada_import,
"ply": bpy.ops.import_mesh.ply,
"abc": bpy.ops.wm.alembic_import,
"blend": bpy.ops.wm.append,
}
EXT = {
'PNG': 'png',
'JPEG': 'jpg',
'OPEN_EXR': 'exr',
'TIFF': 'tiff',
'BMP': 'bmp',
'HDR': 'hdr',
'TARGA': 'tga'
}
def init_render(engine='CYCLES', resolution=512):
bpy.context.scene.render.engine = engine
bpy.context.scene.render.resolution_x = resolution
bpy.context.scene.render.resolution_y = resolution
bpy.context.scene.render.resolution_percentage = 100
bpy.context.scene.render.image_settings.file_format = 'PNG'
bpy.context.scene.render.image_settings.color_mode = 'RGBA'
bpy.context.scene.render.film_transparent = True
bpy.context.scene.cycles.device = 'GPU'
bpy.context.scene.cycles.samples = 32
bpy.context.scene.cycles.filter_type = 'BOX'
bpy.context.scene.cycles.filter_width = 1
bpy.context.scene.cycles.diffuse_bounces = 1
bpy.context.scene.cycles.glossy_bounces = 1
bpy.context.scene.cycles.transparent_max_bounces = 3
bpy.context.scene.cycles.transmission_bounces = 3
bpy.context.scene.cycles.use_denoising = True
bpy.context.preferences.addons['cycles'].preferences.get_devices()
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
def init_scene() -> None:
"""Resets the scene to a clean state.
Returns:
None
"""
# delete everything
for obj in bpy.data.objects:
bpy.data.objects.remove(obj, do_unlink=True)
# delete all the materials
for material in bpy.data.materials:
bpy.data.materials.remove(material, do_unlink=True)
# delete all the textures
for texture in bpy.data.textures:
bpy.data.textures.remove(texture, do_unlink=True)
# delete all the images
for image in bpy.data.images:
bpy.data.images.remove(image, do_unlink=True)
def init_camera():
cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
bpy.context.collection.objects.link(cam)
bpy.context.scene.camera = cam
cam.data.sensor_height = cam.data.sensor_width = 32
cam_constraint = cam.constraints.new(type='TRACK_TO')
cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
cam_constraint.up_axis = 'UP_Y'
cam_empty = bpy.data.objects.new("Empty", None)
cam_empty.location = (0, 0, 0)
bpy.context.scene.collection.objects.link(cam_empty)
cam_constraint.target = cam_empty
return cam
def init_uniform_lighting():
# Clear existing lights
bpy.ops.object.select_all(action="DESELECT")
bpy.ops.object.select_by_type(type="LIGHT")
bpy.ops.object.delete()
# Create environment light
if bpy.context.scene.world is None:
world = bpy.data.worlds.new("World")
bpy.context.scene.world = world
else:
world = bpy.context.scene.world
# Enabling nodes
world.use_nodes = True
node_tree = world.node_tree
nodes = node_tree.nodes
links = node_tree.links
# Remove default nodes
for node in nodes:
nodes.remove(node)
# Create background node
bg_node = nodes.new(type="ShaderNodeBackground")
bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
bg_node.inputs["Strength"].default_value = 1.0
output_node = nodes.new(type="ShaderNodeOutputWorld")
links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])
def init_random_lighting(camera_dir: np.ndarray) -> None:
# Clear existing lights
bpy.ops.object.select_all(action="DESELECT")
bpy.ops.object.select_by_type(type="LIGHT")
bpy.ops.object.delete()
# Create environment light
if bpy.context.scene.world is None:
world = bpy.data.worlds.new("World")
bpy.context.scene.world = world
else:
world = bpy.context.scene.world
# Enabling nodes
world.use_nodes = True
node_tree = world.node_tree
nodes = node_tree.nodes
links = node_tree.links
# Remove default nodes
for node in nodes:
nodes.remove(node)
# Random place lights
num_lights = np.random.randint(1, 4)
total_strength = 1.5
for i in range(num_lights):
new_light = bpy.data.objects.new(f"Light_{i}", bpy.data.lights.new(f"Light_{i}", type="POINT"))
bpy.context.collection.objects.link(new_light)
new_light_distance = 1 / np.random.uniform(1/100, 1/10)
new_light_dir = np.random.randn(3)
new_light_dir[2] += 0.6
new_light_dir = new_light_dir / np.linalg.norm(new_light_dir)
new_light_location = new_light_dir * new_light_distance
new_light_camera_strength_ratio = max(np.sum(camera_dir * new_light_dir) * 0.5 + 0.5, 0)
new_light_max_energy = total_strength / (np.sum(camera_dir * new_light_dir) * 0.45 + 0.55)
new_light_strength = np.sqrt(np.random.uniform(0.01, 1)) * new_light_max_energy
new_light_camera_strength = new_light_camera_strength_ratio * new_light_strength
total_strength -= new_light_camera_strength
new_light.location = (new_light_location[0], new_light_location[1], new_light_location[2])
new_light.data.color = (1.0, 1.0, 1.0)
new_light.data.energy = new_light_strength * new_light_distance**2 * 31.4
new_light.data.shadow_soft_size = np.random.uniform(0.1, 0.1 * new_light_distance)
# Create background node
bg_node = nodes.new(type="ShaderNodeBackground")
bg_node.inputs["Color"].default_value = (1.0, 1.0, 1.0, 1.0)
bg_node.inputs["Strength"].default_value = total_strength
output_node = nodes.new(type="ShaderNodeOutputWorld")
links.new(bg_node.outputs["Background"], output_node.inputs["Surface"])
def load_object(object_path: str) -> None:
"""Loads a model with a supported file extension into the scene.
Args:
object_path (str): Path to the model file.
Raises:
ValueError: If the file extension is not supported.
Returns:
None
"""
file_extension = object_path.split(".")[-1].lower()
if file_extension is None:
raise ValueError(f"Unsupported file type: {object_path}")
if file_extension == "usdz":
# install usdz io package
dirname = os.path.dirname(os.path.realpath(__file__))
usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
bpy.ops.preferences.addon_install(filepath=usdz_package)
# enable it
addon_name = "io_scene_usdz"
bpy.ops.preferences.addon_enable(module=addon_name)
# import the usdz
from io_scene_usdz.import_usdz import import_usdz
import_usdz(context, filepath=object_path, materials=True, animations=True)
return None
# load from existing import functions
import_function = IMPORT_FUNCTIONS[file_extension]
print(f"Loading object from {object_path}")
if file_extension == "blend":
import_function(directory=object_path, link=False)
elif file_extension in {"glb", "gltf"}:
import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
else:
import_function(filepath=object_path)
def delete_invisible_objects() -> None:
"""Deletes all invisible objects in the scene.
Returns:
None
"""
# bpy.ops.object.mode_set(mode="OBJECT")
bpy.ops.object.select_all(action="DESELECT")
for obj in bpy.context.scene.objects:
if obj.hide_viewport or obj.hide_render:
obj.hide_viewport = False
obj.hide_render = False
obj.hide_select = False
obj.select_set(True)
bpy.ops.object.delete()
# Delete invisible collections
invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
for col in invisible_collections:
bpy.data.collections.remove(col)
def unhide_all_objects() -> None:
"""Unhides all objects in the scene.
Returns:
None
"""
for obj in bpy.context.scene.objects:
obj.hide_set(False)
def convert_to_meshes() -> None:
"""Converts all objects in the scene to meshes.
Returns:
None
"""
bpy.ops.object.select_all(action="DESELECT")
bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
for obj in bpy.context.scene.objects:
obj.select_set(True)
bpy.ops.object.convert(target="MESH")
def triangulate_meshes() -> None:
"""Triangulates all meshes in the scene.
Returns:
None
"""
bpy.ops.object.select_all(action="DESELECT")
objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
bpy.context.view_layer.objects.active = objs[0]
for obj in objs:
obj.select_set(True)
bpy.ops.object.mode_set(mode="EDIT")
bpy.ops.mesh.reveal()
bpy.ops.mesh.select_all(action="SELECT")
bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
bpy.ops.object.mode_set(mode="OBJECT")
bpy.ops.object.select_all(action="DESELECT")
def scene_bbox() -> Tuple[Vector, Vector]:
"""Returns the bounding box of the scene.
Taken from Shap-E rendering script
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
Returns:
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
"""
bbox_min = (math.inf,) * 3
bbox_max = (-math.inf,) * 3
found = False
scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
for obj in scene_meshes:
found = True
for coord in obj.bound_box:
coord = Vector(coord)
coord = obj.matrix_world @ coord
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
if not found:
raise RuntimeError("no objects in scene to compute bounding box for")
return Vector(bbox_min), Vector(bbox_max)
def normalize_scene() -> Tuple[float, Vector]:
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
at the origin.
Mostly taken from the Point-E / Shap-E rendering script
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
but fix for multiple root objects: (see bug report here:
https://github.com/openai/shap-e/pull/60).
Returns:
Tuple[float, Vector]: The scale factor and the offset applied to the scene.
"""
scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
if len(scene_root_objects) > 1:
# create an empty object to be used as a parent for all root objects
scene = bpy.data.objects.new("ParentEmpty", None)
bpy.context.scene.collection.objects.link(scene)
# parent all root objects to the empty object
for obj in scene_root_objects:
obj.parent = scene
else:
scene = scene_root_objects[0]
bbox_min, bbox_max = scene_bbox()
scale = 1 / max(bbox_max - bbox_min)
scene.scale = scene.scale * scale
# Apply scale to matrix_world.
bpy.context.view_layer.update()
bbox_min, bbox_max = scene_bbox()
offset = -(bbox_min + bbox_max) / 2
scene.matrix_world.translation += offset
bpy.ops.object.select_all(action="DESELECT")
return scale, offset
def get_transform_matrix(obj: bpy.types.Object) -> list:
pos, rt, _ = obj.matrix_world.decompose()
rt = rt.to_matrix()
matrix = []
for ii in range(3):
a = []
for jj in range(3):
a.append(rt[ii][jj])
a.append(pos[ii])
matrix.append(a)
matrix.append([0, 0, 0, 1])
return matrix
def main(arg):
if arg.object.endswith(".blend"):
delete_invisible_objects()
else:
init_scene()
load_object(arg.object)
print('[INFO] Scene initialized.')
# normalize scene
scale, offset = normalize_scene()
print('[INFO] Scene normalized.')
# Initialize camera and lighting
cam = init_camera()
init_uniform_lighting()
print('[INFO] Camera and lighting initialized.')
# ============= Render conditional views =============
init_render(engine=arg.engine, resolution=arg.cond_resolution)
# Create a list of views
to_export = {
"aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
"scale": scale,
"offset": [offset.x, offset.y, offset.z],
"frames": []
}
views = json.loads(arg.cond_views)
for i, view in enumerate(views):
cam_dir = np.array([
np.cos(view['yaw']) * np.cos(view['pitch']),
np.sin(view['yaw']) * np.cos(view['pitch']),
np.sin(view['pitch'])
])
init_random_lighting(cam_dir)
cam.location = (
view['radius'] * cam_dir[0],
view['radius'] * cam_dir[1],
view['radius'] * cam_dir[2]
)
cam.data.lens = 16 / np.tan(view['fov'] / 2)
bpy.context.scene.render.filepath = os.path.join(arg.cond_output_folder, f'{i:03d}.png')
# Render the scene
bpy.ops.render.render(write_still=True)
bpy.context.view_layer.update()
# Save camera parameters
metadata = {
"file_path": f'{i:03d}.png',
"camera_angle_x": view['fov'],
"transform_matrix": get_transform_matrix(cam)
}
to_export["frames"].append(metadata)
# Save the camera parameters
with open(os.path.join(arg.cond_output_folder, 'transforms.json'), 'w') as f:
json.dump(to_export, f, indent=4)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
parser.add_argument('--cond_views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
parser.add_argument('--cond_output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
parser.add_argument('--cond_resolution', type=int, default=1024, help='Resolution of the conditional images.')
parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
argv = sys.argv[sys.argv.index("--") + 1:]
args = parser.parse_args(argv)
main(args)
\ No newline at end of file
import os
import shutil
import sys
import time
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
def update_metadata(path, opt):
if not os.path.exists(path):
return None
timestamp = str(int(time.time()))
os.makedirs(os.path.join(path, 'merged_records'), exist_ok=True)
os.makedirs(os.path.join(path, 'new_records'), exist_ok=True)
if opt.from_merged_records:
df_files = [f for f in os.listdir(os.path.join(path, 'merged_records')) if f.endswith('.csv')]
df_files = [f for f in df_files if int(f.split('_')[0]) >= opt.record_start]
else:
df_files = [f for f in os.listdir(os.path.join(path, 'new_records')) if f.startswith('part_') and f.endswith('.csv')]
df_parts = []
for f in df_files:
try:
df_parts.append(pd.read_csv(os.path.join(path, 'new_records', f)))
except Exception as e:
print(f"Failed to read {f}: {e}")
if len(df_parts) > 0:
if os.path.exists(os.path.join(path, 'metadata.csv')):
metadata = pd.read_csv(os.path.join(path, 'metadata.csv'))
else:
columns = df_parts[0].columns
metadata = pd.DataFrame(columns=columns)
metadata.set_index('sha256', inplace=True)
for df_part in df_parts:
if 'sha256' in df_part.columns:
df_part.set_index('sha256', inplace=True)
metadata = df_part.combine_first(metadata)
metadata.to_csv(os.path.join(path, 'metadata.csv'))
for f in df_files:
shutil.move(os.path.join(path, 'new_records', f), os.path.join(path, 'merged_records', f'{timestamp}_{f}'))
return metadata
else:
if os.path.exists(os.path.join(path, 'metadata.csv')):
return pd.read_csv(os.path.join(path, 'metadata.csv'))
return None
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--thumbnail_root', type=str, default=None,
help='Directory to save the thumbnail files')
parser.add_argument('--render_cond_root', type=str, default=None,
help='Directory to save the render condition files')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to save the mesh files')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to save the pbr files')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save the dual grid files')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save the pbr voxel files')
parser.add_argument('--ss_latent_root', type=str, default=None,
help='Directory to save the sparse structure latent files')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--pbr_latent_root', type=str, default=None,
help='Directory to save the pbr latent files')
parser.add_argument('--field', type=str, default='all',
help='Fields to process, separated by commas')
parser.add_argument('--from_file', action='store_true',
help='Build metadata from file instead of from records of processings.' +
'Useful when some processing fail to generate records but file already exists.')
parser.add_argument('--from_merged_records', action='store_true',
help='Build metadata from merged records')
parser.add_argument('--record_start', type=int)
parser.add_argument('--rebuild', action='store_true',
help='Rebuild metadata from scratch, ignore existing metadata.')
dataset_utils.add_args(parser)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.thumbnail_root = opt.thumbnail_root or opt.root
opt.render_cond_root = opt.render_cond_root or opt.root
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
opt.dual_grid_root = opt.dual_grid_root or opt.root
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
opt.ss_latent_root = opt.ss_latent_root or opt.root
opt.shape_latent_root = opt.shape_latent_root or opt.root
opt.pbr_latent_root = opt.pbr_latent_root or opt.root
os.makedirs(opt.root, exist_ok=True)
opt.field = opt.field.split(',')
# get file list
if os.path.exists(os.path.join(opt.root, 'metadata.csv')):
print('Loading previous metadata...')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv'))
else:
metadata = dataset_utils.get_metadata(**opt)
metadata.to_csv(os.path.join(opt.root, 'metadata.csv'), index=False)
# merge downloaded
downloaded_metadata = update_metadata(os.path.join(opt.download_root, 'raw'), opt)
# merge thumbnails
thumbnail_metadata = update_metadata(os.path.join(opt.thumbnail_root, 'thumbnails'), opt)
# merge aesthetic scores
aesthetic_score_metadata = update_metadata(os.path.join(opt.root, 'aesthetic_scores'), opt)
# merge render conditions
render_cond_metadata = update_metadata(os.path.join(opt.render_cond_root, 'renders_cond'), opt)
# merge mesh dumped
mesh_dumped_metadata = update_metadata(os.path.join(opt.mesh_dump_root, 'mesh_dumps'), opt)
# merge pbr dumped
pbr_dumped_metadata = update_metadata(os.path.join(opt.pbr_dump_root, 'pbr_dumps'), opt)
# merge asset stats
asset_stats_metadata = update_metadata(os.path.join(opt.root, 'asset_stats'), opt)
# merge dual grid
dual_grid_resolutions = []
for dir in os.listdir(opt.dual_grid_root):
if os.path.isdir(os.path.join(opt.dual_grid_root, dir)) and dir.startswith('dual_grid_'):
dual_grid_resolutions.append(int(dir.split('_')[-1]))
dual_grid_metadata = {}
for res in dual_grid_resolutions:
dual_grid_metadata[res] = update_metadata(os.path.join(opt.dual_grid_root, f'dual_grid_{res}'), opt)
# merge pbr voxelized
pbr_voxel_resolutions = []
for dir in os.listdir(opt.pbr_voxel_root):
if os.path.isdir(os.path.join(opt.pbr_voxel_root, dir)) and dir.startswith('pbr_voxels_'):
pbr_voxel_resolutions.append(int(dir.split('_')[-1]))
pbr_voxel_metadata = {}
for res in pbr_voxel_resolutions:
pbr_voxel_metadata[res] = update_metadata(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}'), opt)
# merge ss latents
ss_latent_models = []
if os.path.exists(os.path.join(opt.ss_latent_root, 'ss_latents')):
ss_latent_models = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents'))
ss_latent_metadata = {}
for model in ss_latent_models:
ss_latent_metadata[model] = update_metadata(os.path.join(opt.ss_latent_root, f'ss_latents/{model}'), opt)
# merge shape latents
shape_latent_models = []
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents')):
shape_latent_models = os.listdir(os.path.join(opt.shape_latent_root, 'shape_latents'))
shape_latent_metadata = {}
for model in shape_latent_models:
shape_latent_metadata[model] = update_metadata(os.path.join(opt.shape_latent_root, f'shape_latents/{model}'), opt)
# merge pbr latents
pbr_latent_models = []
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents')):
pbr_latent_models = os.listdir(os.path.join(opt.pbr_latent_root, 'pbr_latents'))
pbr_latent_metadata = {}
for model in pbr_latent_models:
pbr_latent_metadata[model] = update_metadata(os.path.join(opt.pbr_latent_root, f'pbr_latents/{model}'), opt)
# statistics
num_downloaded = downloaded_metadata['local_path'].count() if downloaded_metadata is not None else 0
with open(os.path.join(opt.root, 'statistics.txt'), 'w') as f:
f.write('Statistics:\n')
f.write(f' - Number of assets: {len(metadata)}\n')
f.write(f' - Number of assets downloaded: {num_downloaded}\n')
if thumbnail_metadata is not None:
f.write(f' - Number of assets with thumbnails: {thumbnail_metadata["thumbnailed"].sum()}\n')
if aesthetic_score_metadata is not None:
f.write(f' - Number of assets with aesthetic scores: {aesthetic_score_metadata["aesthetic_score"].count()}\n')
if render_cond_metadata is not None:
f.write(f' - Number of assets with render conditions: {render_cond_metadata["cond_rendered"].count()}\n')
if mesh_dumped_metadata is not None:
f.write(f' - Number of assets with mesh dumped: {mesh_dumped_metadata["mesh_dumped"].sum()}\n')
if pbr_dumped_metadata is not None:
f.write(f' - Number of assets with PBR dumped: {pbr_dumped_metadata["pbr_dumped"].sum()}\n')
if asset_stats_metadata is not None:
f.write(f' - Number of assets with asset stats: {len(asset_stats_metadata)}\n')
if len(dual_grid_resolutions) != 0:
f.write(f' - Number of assets with dual grid:\n')
for res in dual_grid_resolutions:
if dual_grid_metadata[res] is not None:
f.write(f' - {res}: {dual_grid_metadata[res]["dual_grid_converted"].sum()}\n')
if len(pbr_voxel_resolutions) != 0:
f.write(f' - Number of assets with PBR voxelization:\n')
for res in pbr_voxel_resolutions:
if pbr_voxel_metadata[res] is not None:
f.write(f' - {res}: {pbr_voxel_metadata[res]["pbr_voxelized"].sum()}\n')
if len(ss_latent_models) != 0:
f.write(f' - Number of assets with sparse structure latents:\n')
for model in ss_latent_models:
if ss_latent_metadata[model] is not None:
f.write(f' - {model}: {ss_latent_metadata[model]["ss_latent_encoded"].sum()}\n')
if len(shape_latent_models) != 0:
f.write(f' - Number of assets with shape latents:\n')
for model in shape_latent_models:
if shape_latent_metadata[model] is not None:
f.write(f' - {model}: {shape_latent_metadata[model]["shape_latent_encoded"].sum()}\n')
if len(pbr_latent_models) != 0:
f.write(f' - Number of assets with PBR latents:\n')
for model in pbr_latent_models:
if pbr_latent_metadata[model] is not None:
f.write(f' - {model}: {pbr_latent_metadata[model]["pbr_latent_encoded"].sum()}\n')
with open(os.path.join(opt.root, 'statistics.txt'), 'r') as f:
print(f.read())
\ No newline at end of file
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to download the objects')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--check_only', action='store_true',
help='Only check if the objects are already downloaded')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
os.makedirs(opt.root, exist_ok=True)
os.makedirs(opt.download_root, exist_ok=True)
os.makedirs(os.path.join(opt.download_root, 'raw', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'local_path' in metadata.columns:
metadata = metadata[metadata['local_path'].isna()]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
downloaded = dataset_utils.download(metadata, **opt)
downloaded.to_csv(os.path.join(opt.download_root, 'raw', 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import sys
import importlib
import argparse
import pandas as pd
import numpy as np
import torch
import pickle
import o_voxel
from easydict import EasyDict as edict
from functools import partial
def _dual_grid_mesh(file, metadatum, mesh_dump_root, root):
sha256 = metadatum['sha256']
try:
pack = {'sha256': sha256}
data = None
for res in opt.resolution:
need_process = False
# check if already processed
if os.path.exists(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz')):
try:
info = o_voxel.io.read_vxz_info(os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'))
pack[f'dual_grid_converted_{res}'] = True
pack[f'dual_grid_size_{res}'] = info['num_voxel']
except Exception as e:
print(f'Error reading {sha256}.vxz: {e}')
need_process = True
else:
need_process = True
# process mesh
if need_process:
if data is None:
with open(os.path.join(mesh_dump_root, 'mesh_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
start = 0
vertices = []
faces = []
for obj in dump['objects']:
if obj['vertices'].size == 0 or obj['faces'].size == 0:
continue
vertices.append(obj['vertices'])
faces.append(obj['faces'] + start)
start += len(obj['vertices'])
vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
vertices = (vertices - center) * scale
assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
data = {'vertices': vertices, 'faces': faces}
voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
**data,
grid_size=res,
aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
face_weight=1.0,
boundary_weight=0.2,
regularization_weight=1e-2,
timing=False,
)
dual_vertices = dual_vertices * res - voxel_indices
assert torch.all(dual_vertices >= -1e-3) and torch.all(dual_vertices <= 1+1e-3), 'dual_vertices out of range'
dual_vertices = torch.clamp(dual_vertices, 0, 1)
dual_vertices = (dual_vertices * 255).type(torch.uint8)
intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8)
o_voxel.io.write_vxz(
os.path.join(root, f'dual_grid_{res}', f'{sha256}.vxz'),
voxel_indices,
{'vertices': dual_vertices, 'intersected': intersected},
)
pack[f'dual_grid_converted_{res}'] = True
pack[f'dual_grid_size_{res}'] = len(dual_vertices)
return pack
except Exception as e:
print(f'Error processing {sha256}: {e}')
return {'sha256': sha256, 'error': str(e)}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to load mesh dumps')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save dual grids')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--resolution', type=str, default=256)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.resolution = [int(x) for x in opt.resolution.split(',')]
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
opt.dual_grid_root = opt.dual_grid_root or opt.root
for res in opt.resolution:
os.makedirs(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
for res in opt.resolution:
if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')):
dual_grid_metadata = pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'metadata.csv')).set_index('sha256')
dual_grid_metadata = dual_grid_metadata.rename(columns={'dual_grid_converted': f'dual_grid_converted_{res}', 'dual_grid_size': f'dual_grid_size_{res}'})
metadata = metadata.combine_first(dual_grid_metadata)
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['mesh_dumped'] == True]
mask = np.zeros(len(metadata), dtype=bool)
for res in opt.resolution:
if f'dual_grid_converted_{res}' in metadata.columns:
mask |= metadata[f'dual_grid_converted_{res}'] != True
else:
mask[:] = True
break
metadata = metadata[mask]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dual_grid_mesh, root=opt.dual_grid_root, mesh_dump_root=opt.mesh_dump_root)
dual_grids = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Dual griding')
if 'error' in dual_grids.columns:
errors = dual_grids[dual_grids[f'error'].notna()]
with open('errors.txt', 'w') as f:
f.write('\n'.join(errors['sha256'].tolist()))
for res in opt.resolution:
if f'dual_grid_converted_{res}' in dual_grids.columns:
dual_grid_metadata = dual_grids[dual_grids[f'dual_grid_converted_{res}'] == True]
if len(dual_grid_metadata) > 0:
dual_grid_metadata = dual_grid_metadata[['sha256', f'dual_grid_converted_{res}', f'dual_grid_size_{res}']]
dual_grid_metadata = dual_grid_metadata.rename(columns={f'dual_grid_converted_{res}': 'dual_grid_converted', f'dual_grid_size_{res}': 'dual_grid_size'})
dual_grid_metadata.to_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)
\ No newline at end of file
import os
import shutil
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
import tempfile
BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _dump_mesh(file_path, metadatum, root):
sha256 = metadatum['sha256']
with tempfile.TemporaryDirectory() as tmp_dir:
temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
output_path = os.path.join(root, 'mesh_dumps', f'{sha256}.pickle')
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_mesh.py'),
'--',
'--object', os.path.expanduser(file_path),
'--output_path', os.path.expanduser(temp_path)
]
if file_path.endswith('.blend'):
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(temp_path):
shutil.move(temp_path, output_path)
return {'sha256': sha256, 'mesh_dumped': True}
else:
if os.path.exists(temp_path + '_error.txt'):
with open(temp_path + '_error.txt', 'r') as f:
error_msg = f.read()
raise ValueError(f'Failed to dump mesh. File {file_path}. Error message: {error_msg}')
else:
raise ValueError(f'Failed to dump mesh. File {file_path}.')
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--mesh_dump_root', type=str, default=None,
help='Directory to save the mesh dumps')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.mesh_dump_root = opt.mesh_dump_root or opt.root
os.makedirs(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'mesh_dumped' in metadata.columns:
metadata = metadata[metadata['mesh_dumped'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.mesh_dump_root, 'mesh_dumps'))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'mesh_dumped': True})
print(f'Found {len(sha256_list)} dumped mesh')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dump_mesh, root=opt.mesh_dump_root)
mesh_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping mesh')
mesh_dumped = pd.concat([mesh_dumped, pd.DataFrame.from_records(records)])
mesh_dumped.to_csv(os.path.join(opt.mesh_dump_root, 'mesh_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import shutil
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
import tempfile
BLENDER_LINK = 'https://ftp.halifax.rwth-aachen.de/blender/release/Blender4.5/blender-4.5.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-4.5.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
os.system(f'{BLENDER_PATH} -b --python {os.path.join(os.path.dirname(__file__), "blender_script", "install_pillow.py")}')
def _dump_pbr(file_path, metadatum, root):
sha256 = metadatum['sha256']
with tempfile.TemporaryDirectory() as tmp_dir:
temp_path = os.path.join(tmp_dir, f'{sha256}.pickle')
output_path = os.path.join(root, 'pbr_dumps', f'{sha256}.pickle')
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'dump_pbr.py'),
'--',
'--object', os.path.expanduser(file_path),
'--output_path', os.path.expanduser(temp_path)
]
if file_path.endswith('.blend'):
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(temp_path):
shutil.move(temp_path, output_path)
return {'sha256': sha256, 'pbr_dumped': True}
else:
if os.path.exists(temp_path + '_error.txt'):
with open(temp_path + '_error.txt', 'r') as f:
error_msg = f.read()
raise ValueError(f'Failed to dump PBR. File {file_path}. Error message: {error_msg}')
else:
raise ValueError(f'Failed to dump PBR. File {file_path}.')
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to save the mesh dumps')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
os.makedirs(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'pbr_dumped' in metadata.columns:
metadata = metadata[metadata['pbr_dumped'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.pbr_dump_root, 'pbr_dumps'))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.pickle')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'pbr_dumped': True})
print(f'Found {len(sha256_list)} dumped PBRs')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_dump_pbr, root=opt.pbr_dump_root)
pbr_dumped = dataset_utils.foreach_instance(metadata, opt.download_root, func, max_workers=opt.max_workers, desc='Dumping PBR')
pbr_dumped = pd.concat([pbr_dumped, pd.DataFrame.from_records(records)])
pbr_dumped.to_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
import trellis2.modules.sparse as sp
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save the pbr voxel files')
parser.add_argument('--pbr_latent_root', type=str, default=None,
help='Directory to save the pbr latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=1024,
help='Sparse voxel resolution')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
parser.add_argument('--enc_model', type=str,
help='Encoder model. if specified, use this model instead of pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
opt.pbr_latent_root = opt.pbr_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['pbr_voxelized'] == True]
if 'pbr_latent_encoded' in metadata.columns:
metadata = metadata[metadata['pbr_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')):
coords = np.load(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz'))['coords']
records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': coords.shape[0]})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
print(f'Found {len(existing_sha256)} processed objects')
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
attrs = ['base_color', 'metallic', 'roughness', 'alpha']
coords, attr = o_voxel.io.read_vxz(
os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{opt.resolution}', f'{sha256}.vxz'),
num_threads=4
)
feats = torch.concat([attr[k] for k in attrs], dim=-1) / 255.0 * 2 - 1
x = sp.SparseTensor(
feats.float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
load_queue.put((sha256, x))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'pbr_latent_encoded': True, 'pbr_latent_tokens': pack['coords'].shape[0]})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
try:
sha256, voxels = load_queue.get()
if voxels is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
num_voxels = voxels.feats.shape[0]
# NaN/Inf
if not (is_valid_sparse_tensor(voxels)):
print(f"[Skip] {sha256}: NaN/Inf in input")
continue
z = encoder(voxels.cuda())
torch.cuda.synchronize()
if not torch.isfinite(z.feats).all():
print(f"[Skip] {sha256}: Non-finite latent in z.feats")
clear_cuda_error()
continue
pack = {
'feats': z.feats.cpu().numpy().astype(np.float32),
'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.pbr_latent_root, 'pbr_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
import o_voxel
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
import trellis2.modules.sparse as sp
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--dual_grid_root', type=str, default=None,
help='Directory to save the dual grids')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=1024,
help='Sparse voxel resolution')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
parser.add_argument('--enc_model', type=str,
help='Encoder model. if specified, use this model instead of pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.dual_grid_root = opt.dual_grid_root or opt.root
opt.shape_latent_root = opt.shape_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['dual_grid_converted'] == True]
if 'shape_latent_encoded' in metadata.columns:
metadata = metadata[metadata['shape_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')):
coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz'))['coords']
records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': coords.shape[0]})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
print(f'Found {len(existing_sha256)} processed objects')
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
coords, attr = o_voxel.io.read_vxz(
os.path.join(opt.dual_grid_root, f'dual_grid_{opt.resolution}', f'{sha256}.vxz'),
num_threads=4
)
vertices = sp.SparseTensor(
(attr['vertices'] / 255.0).float(),
torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
)
intersected = vertices.replace(torch.cat([
attr['intersected'] % 2,
attr['intersected'] // 2 % 2,
attr['intersected'] // 4 % 2,
], dim=-1).bool())
load_queue.put((sha256, vertices, intersected))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'shape_latent_encoded': True, 'shape_latent_tokens': pack['coords'].shape[0]})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
try:
sha256, vertices, intersected = load_queue.get()
if vertices is None or intersected is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
num_voxels = vertices.feats.shape[0]
# NaN/Inf
if not (is_valid_sparse_tensor(vertices) and is_valid_sparse_tensor(intersected)):
print(f"[Skip] {sha256}: NaN/Inf in input")
continue
z = encoder(vertices.cuda(), intersected.cuda())
torch.cuda.synchronize()
if not torch.isfinite(z.feats).all():
print(f"[Skip] {sha256}: Non-finite latent in z.feats")
clear_cuda_error()
continue
pack = {
'feats': z.feats.cpu().numpy().astype(np.float32),
'coords': z.coords[:, 1:].cpu().numpy().astype(np.uint8),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256} ({num_voxels} voxels): {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.shape_latent_root, 'shape_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis2.models as models
torch.set_grad_enabled(False)
def is_valid_sparse_tensor(tensor):
return torch.isfinite(tensor.feats).all() and torch.isfinite(tensor.coords).all()
def clear_cuda_error():
torch.cuda.synchronize()
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--shape_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--ss_latent_root', type=str, default=None,
help='Directory to save the shape latent files')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--resolution', type=int, default=64,
help='Sparse voxel resolution')
parser.add_argument('--shape_latent_name', type=str, default=None,
help='Name of the shape latent files')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str,
help='Root directory of models')
parser.add_argument('--enc_model', type=str,
help='Encoder model. if specified, use this model instead of pretrained model')
parser.add_argument('--ckpt', type=str,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
opt.shape_latent_root = opt.shape_latent_root or opt.root
opt.ss_latent_root = opt.ss_latent_root or opt.root
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}_{opt.resolution}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model.split("/")[-1]}_{opt.ckpt}_{opt.resolution}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name,'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.ss_latent_root,'ss_latents', latent_name, 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.ss_latent_root,'ss_latents', latent_name,'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['shape_latent_encoded'] == True]
if 'ss_latent_encoded' in metadata.columns:
metadata = metadata[metadata['ss_latent_encoded'] != True]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256_list = os.listdir(os.path.join(opt.ss_latent_root, 'ss_latents'))
sha256_list = [os.path.splitext(f)[0] for f in sha256_list if f.endswith('.npz')]
for sha256 in sha256_list:
records.append({'sha256': sha256, 'ss_latent_encoded': True})
print(f'Found {len(sha256_list)} processed objects')
metadata = metadata[~metadata['sha256'].isin(sha256_list)]
print(f'Processing {len(metadata)} objects...')
sha256s = list(metadata['sha256'].values)
load_queue = Queue(maxsize=32)
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
coords = np.load(os.path.join(opt.shape_latent_root, 'shape_latents', opt.shape_latent_name, f'{sha256}.npz'))['coords']
assert np.all(coords < opt.resolution), f"{sha256}: Invalid coords"
coords = torch.from_numpy(coords).long()
ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
load_queue.put((sha256, ss))
except Exception as e:
print(f"[Loader Error] {sha256}: {e}")
load_queue.put((sha256, None))
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, 'ss_latent_encoded': True})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
try:
sha256, ss = load_queue.get()
if ss is None:
print(f"[Skip] {sha256}: Failed to load input")
continue
ss = ss.cuda()[None].float()
z = encoder(ss, sample_posterior=False)
torch.cuda.synchronize()
if not torch.isfinite(z).all():
print(f"[Skip] {sha256}: Non-finite latent")
clear_cuda_error()
continue
pack = {
'z': z[0].cpu().numpy(),
}
saver_executor.submit(saver, sha256, pack)
except Exception as e:
print(f"[Error] {sha256}: {e}")
clear_cuda_error()
continue
saver_executor.shutdown(wait=True)
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.ss_latent_root, 'ss_latents', latent_name, 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import numpy as np
from utils import sphere_hammersley_sequence
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
def _install_blender():
if not os.path.exists(BLENDER_PATH):
os.system('sudo apt-get update')
os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6 libxfixes3 libgl1')
os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _render_cond(file_path, metadatum, root, num_cond_views):
sha256 = metadatum['sha256']
# Build conditional view camera
yaws = []
pitchs = []
offset = (np.random.rand(), np.random.rand())
for i in range(num_cond_views):
y, p = sphere_hammersley_sequence(i, num_cond_views, offset)
yaws.append(y)
pitchs.append(p)
fov_min, fov_max = 10, 70
radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
k_min = 1 / radius_max**2
k_max = 1 / radius_min**2
ks = np.random.uniform(k_min, k_max, (1000000,))
radius = [1 / np.sqrt(k) for k in ks]
fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
cond_views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
args = [
BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render_cond.py'),
'--',
'--object', os.path.expanduser(file_path),
'--cond_views', json.dumps(cond_views),
'--cond_resolution', '1024',
'--cond_output_folder', os.path.join(root, 'renders_cond', sha256),
'--engine', 'CYCLES',
]
if file_path.endswith('.blend'):
args.insert(1, file_path)
call(args, stdout=DEVNULL, stderr=DEVNULL)
if os.path.exists(os.path.join(root, 'renders_cond', sha256, 'transforms.json')):
return {'sha256': sha256, 'cond_rendered': True}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--download_root', type=str, default=None,
help='Directory to save the downloaded files')
parser.add_argument('--render_cond_root', type=str, default=None,
help='Directory to save the mesh dumps')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--num_cond_views', type=int, default=16,
help='Number of conditional views to render')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=8)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.download_root = opt.download_root or opt.root
opt.render_cond_root = opt.render_cond_root or opt.root
os.makedirs(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records'), exist_ok=True)
# install blender
print('Checking blender...', flush=True)
_install_blender()
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.download_root, 'raw', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.download_root, 'raw', 'metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'metadata.csv')).set_index('sha256'))
metadata = metadata.reset_index()
if opt.instances is None:
metadata = metadata[metadata['local_path'].notna()]
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'cond_rendered' in metadata.columns:
metadata = metadata[(metadata['cond_rendered'] != True)]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
tqdm(total=len(metadata), desc="Filtering existing objects") as pbar:
def check_sha256(sha256):
if os.path.exists(os.path.join(opt.render_cond_root, 'renders_cond', sha256, 'transforms.json')):
records.append({'sha256': sha256, 'cond_rendered': True})
pbar.update()
executor.map(check_sha256, metadata['sha256'].values)
executor.shutdown(wait=True)
existing_sha256 = set(r['sha256'] for r in records)
metadata = metadata[~metadata['sha256'].isin(existing_sha256)]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_render_cond, root=opt.render_cond_root, num_cond_views=opt.num_cond_views)
cond_rendered = dataset_utils.foreach_instance(metadata, opt.render_cond_root, func, max_workers=opt.max_workers, desc='Rendering objects')
cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
cond_rendered.to_csv(os.path.join(opt.render_cond_root, 'renders_cond', 'new_records', f'part_{opt.rank}.csv'), index=False)
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub[cli] open_clip_torch
from typing import *
import hashlib
import numpy as np
import cv2
def get_file_hash(file: str) -> str:
sha256 = hashlib.sha256()
# Read the file from the path
with open(file, "rb") as f:
# Update the hash with the file content
for byte_block in iter(lambda: f.read(4096), b""):
sha256.update(byte_block)
return sha256.hexdigest()
# ===============LOW DISCREPANCY SEQUENCES================
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
u, v = hammersley_sequence(2, n, num_samples)
u += offset[0] / num_samples
v += offset[1]
u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
theta = np.arccos(1 - 2 * u) - np.pi / 2
phi = v * 2 * np.pi
return [phi, theta]
# ==============PLY IO===============
import struct
import re
import torch
def read_ply(filename):
"""
Read a PLY file and return vertices, triangle faces, and quad faces.
Args:
filename (str): The file path to read from.
Returns:
vertices (torch.Tensor): Tensor of shape [N, 3] containing vertex positions.
tris (torch.Tensor): Tensor of shape [M, 3] containing triangle face indices (empty if none).
quads (torch.Tensor): Tensor of shape [K, 4] containing quad face indices (empty if none).
"""
with open(filename, 'rb') as f:
# Read the header until 'end_header' is encountered
header_bytes = b""
while True:
line = f.readline()
if not line:
raise ValueError("PLY header not found")
header_bytes += line
if b"end_header" in line:
break
header = header_bytes.decode('utf-8')
# Determine if the file is in ASCII or binary format
is_ascii = "ascii" in header
# Extract the number of vertices and faces from the header using regex
vertex_match = re.search(r'element vertex (\d+)', header)
if vertex_match:
num_vertices = int(vertex_match.group(1))
else:
raise ValueError("Vertex count not found in header")
face_match = re.search(r'element face (\d+)', header)
if face_match:
num_faces = int(face_match.group(1))
else:
raise ValueError("Face count not found in header")
vertices = []
tris = []
quads = []
if is_ascii:
# For ASCII format, read each line of vertex data (each line contains 3 floats)
for _ in range(num_vertices):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
# Read face data, where the first number indicates the number of vertices for the face
for _ in range(num_faces):
line = f.readline().decode('utf-8').strip()
if not line:
continue
parts = line.split()
count = int(parts[0])
indices = list(map(int, parts[1:]))
if count == 3:
tris.append(indices)
elif count == 4:
quads.append(indices)
else:
# Skip faces with other numbers of vertices (can be extended as needed)
pass
else:
# For binary format: read directly from the binary stream
# Each vertex consists of 3 floats (12 bytes per vertex)
for _ in range(num_vertices):
data = f.read(12)
if len(data) < 12:
raise ValueError("Insufficient vertex data")
v = struct.unpack('<fff', data)
vertices.append(v)
# Read face data from the binary stream
for _ in range(num_faces):
# First, read 1 byte indicating the number of vertices in the face
count_data = f.read(1)
if len(count_data) < 1:
raise ValueError("Failed to read face vertex count")
count = struct.unpack('<B', count_data)[0]
if count == 3:
data = f.read(12) # 3 * 4 bytes
if len(data) < 12:
raise ValueError("Insufficient data for triangle face")
indices = struct.unpack('<3i', data)
tris.append(indices)
elif count == 4:
data = f.read(16) # 4 * 4 bytes
if len(data) < 16:
raise ValueError("Insufficient data for quad face")
indices = struct.unpack('<4i', data)
quads.append(indices)
else:
# For faces with a different number of vertices, read count*4 bytes
data = f.read(count * 4)
# Skip or extend processing as needed
raise ValueError(f"Unsupported face with {count} vertices")
# Convert lists to torch.Tensor
vertices = torch.tensor(vertices, dtype=torch.float32)
tris = torch.tensor(tris, dtype=torch.int32) if len(tris) > 0 else torch.empty((0, 3), dtype=torch.int32)
quads = torch.tensor(quads, dtype=torch.int32) if len(quads) > 0 else torch.empty((0, 4), dtype=torch.int32)
return vertices, tris, quads
def write_ply(filename, vertices, tris, quads, ascii=False):
"""
Write a mesh to a PLY file, with the option to save in ASCII or binary format.
Args:
filename (str): The filename to write to.
vertices (torch.Tensor): [N, 3] The vertex positions.
tris (torch.Tensor): [M, 3] The triangle indices.
quads (torch.Tensor): [K, 4] The quad indices.
ascii (bool): If True, write in ASCII format. If False, write in binary format.
"""
# Convert torch tensors to numpy arrays
vertices = vertices.numpy()
tris = tris.numpy()
quads = quads.numpy()
# Prepare the header
num_vertices = len(vertices)
num_faces = len(tris) + len(quads)
# Vertex properties
vertex_header = "property float x\nproperty float y\nproperty float z"
# Face properties (the number of vertices per face is variable)
face_header = "property list uchar int vertex_index"
# Start writing the PLY header
header = f"ply\n"
header += f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}\n"
header += f"element vertex {num_vertices}\n"
header += vertex_header + "\n"
header += f"element face {num_faces}\n"
header += face_header + "\n"
header += "end_header\n"
# Open the file for writing
with open(filename, 'wb' if not ascii else 'w') as f:
# Write the header
f.write(header if ascii else header.encode('utf-8'))
# Write the vertex data
if ascii:
for v in vertices:
f.write(f"{v[0]} {v[1]} {v[2]}\n")
else:
for v in vertices:
f.write(struct.pack('<fff', *v))
# Write the face data
if ascii:
for tri in tris:
f.write(f"3 {tri[0]} {tri[1]} {tri[2]}\n")
for quad in quads:
f.write(f"4 {quad[0]} {quad[1]} {quad[2]} {quad[3]}\n")
else:
for tri in tris:
f.write(struct.pack('<B3i', 3, *tri)) # 3 indices for triangle
for quad in quads:
f.write(struct.pack('<B4i', 4, *quad)) # 4 indices for quad
# ==============IMAGE UTILS===============
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
num_images = len(images)
if nrow is None and ncol is None:
if aspect_ratio is not None:
nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
else:
nrow = int(np.sqrt(num_images))
ncol = (num_images + nrow - 1) // nrow
elif nrow is None and ncol is not None:
nrow = (num_images + ncol - 1) // ncol
elif nrow is not None and ncol is None:
ncol = (num_images + nrow - 1) // nrow
else:
assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
if images[0].ndim == 2:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
else:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
for i, img in enumerate(images):
row = i // ncol
col = i % ncol
grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
return grid
def notes_on_image(img, notes=None):
img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if notes is not None:
img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"):
"""
Draw text on an image of the given resolution. The text is automatically wrapped
and scaled so that it fits completely within the image while preserving any explicit
line breaks and original spacing. Horizontal and vertical alignment can be controlled
via flags.
Parameters:
text (str): The input text. Newline characters and spacing are preserved.
resolution (tuple): The image resolution as (width, height).
max_size (float): The maximum font size.
h_align (str): Horizontal alignment. Options: "left", "center", "right".
v_align (str): Vertical alignment. Options: "top", "center", "bottom".
Returns:
numpy.ndarray: The resulting image (BGR format) with the text drawn.
"""
width, height = resolution
# Create a white background image
img = np.full((height, width, 3), 255, dtype=np.uint8)
# Set margins and compute available drawing area
margin = 10
avail_width = width - 2 * margin
avail_height = height - 2 * margin
# Choose OpenCV font and text thickness
font = cv2.FONT_HERSHEY_SIMPLEX
thickness = 1
# Ratio for additional spacing between lines (relative to the height of "A")
line_spacing_ratio = 0.5
def wrap_line(line, max_width, font, thickness, scale):
"""
Wrap a single line of text into multiple lines such that each line's
width (measured at the given scale) does not exceed max_width.
This function preserves the original spacing by splitting the line into tokens
(words and whitespace) using a regular expression.
Parameters:
line (str): The input text line.
max_width (int): Maximum allowed width in pixels.
font (int): OpenCV font identifier.
thickness (int): Text thickness.
scale (float): The current font scale.
Returns:
List[str]: A list of wrapped lines.
"""
# Split the line into tokens (words and whitespace), preserving spacing
tokens = re.split(r'(\s+)', line)
if not tokens:
return ['']
wrapped_lines = []
current_line = ""
for token in tokens:
candidate = current_line + token
candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0]
if candidate_width <= max_width:
current_line = candidate
else:
# If current_line is empty, the token itself is too wide;
# break the token character by character.
if current_line == "":
sub_token = ""
for char in token:
candidate_char = sub_token + char
if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width:
sub_token = candidate_char
else:
if sub_token:
wrapped_lines.append(sub_token)
sub_token = char
current_line = sub_token
else:
wrapped_lines.append(current_line)
current_line = token
if current_line:
wrapped_lines.append(current_line)
return wrapped_lines
def compute_text_block(scale):
"""
Wrap the entire text (splitting at explicit newline characters) using the
provided scale, and then compute the overall width and height of the text block.
Returns:
wrapped_lines (List[str]): The list of wrapped lines.
block_width (int): Maximum width among the wrapped lines.
block_height (int): Total height of the text block including spacing.
sizes (List[tuple]): A list of (width, height) for each wrapped line.
spacing (int): The spacing between lines (computed from the scaled "A" height).
"""
# Split text by explicit newlines
input_lines = text.splitlines() if text else ['']
wrapped_lines = []
for line in input_lines:
wrapped = wrap_line(line, avail_width, font, thickness, scale)
wrapped_lines.extend(wrapped)
sizes = []
for line in wrapped_lines:
(text_size, _) = cv2.getTextSize(line, font, scale, thickness)
sizes.append(text_size) # (width, height)
block_width = max((w for w, h in sizes), default=0)
# Use the height of "A" (at the current scale) to compute line spacing
base_height = cv2.getTextSize("A", font, scale, thickness)[0][1]
spacing = int(line_spacing_ratio * base_height)
block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0
return wrapped_lines, block_width, block_height, sizes, spacing
# Use binary search to find the maximum scale that allows the text block to fit
lo = 0.001
hi = max_size
eps = 0.001 # convergence threshold
best_scale = lo
best_result = None
while hi - lo > eps:
mid = (lo + hi) / 2
wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid)
# Ensure that both width and height constraints are met
if block_width <= avail_width and block_height <= avail_height:
best_scale = mid
best_result = (wrapped_lines, block_width, block_height, sizes, spacing)
lo = mid # try a larger scale
else:
hi = mid # reduce the scale
if best_result is None:
best_scale = 0.5
best_result = compute_text_block(best_scale)
wrapped_lines, block_width, block_height, sizes, spacing = best_result
# Compute starting y-coordinate based on vertical alignment flag
if v_align == "top":
y_top = margin
elif v_align == "center":
y_top = margin + (avail_height - block_height) // 2
elif v_align == "bottom":
y_top = margin + (avail_height - block_height)
else:
y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag
# For cv2.putText, the y coordinate represents the text baseline;
# so for the first line add its height.
y = y_top + (sizes[0][1] if sizes else 0)
# Draw each line with horizontal alignment based on the flag
for i, line in enumerate(wrapped_lines):
line_width, line_height = sizes[i]
if h_align == "left":
x = margin
elif h_align == "center":
x = margin + (avail_width - line_width) // 2
elif h_align == "right":
x = margin + (avail_width - line_width)
else:
x = margin # default to left if invalid flag
cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA)
y += line_height + spacing
return img
def save_image_with_notes(img, path, notes=None):
"""
Save an image with notes.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy().transpose(1, 2, 0)
if img.dtype == np.float32 or img.dtype == np.float64:
img = np.clip(img * 255, 0, 255).astype(np.uint8)
img = notes_on_image(img, notes)
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
import pickle
import numpy as np
import torch
from easydict import EasyDict as edict
from functools import partial
import o_voxel
def _pbr_voxelize(file, metadatum, pbr_dump_root, root):
sha256 = metadatum['sha256']
try:
pack = {'sha256': sha256}
dump = None
for res in opt.resolution:
need_process = False
# check if already processed
if os.path.exists(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz')):
try:
info = o_voxel.io.read_vxz_info(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'))
pack[f'pbr_voxelized_{res}'] = True
pack[f'num_pbr_voxels_{res}'] = info['num_voxel']
except Exception as e:
print(f'Error reading {sha256}.vxz: {e}')
need_process = True
else:
need_process = True
# process if necessary
if need_process:
if dump == None:
with open(os.path.join(pbr_dump_root, 'pbr_dumps', f'{sha256}.pickle'), 'rb') as f:
dump = pickle.load(f)
# Fix dump alpha map
for mat in dump['materials']:
if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE':
mat['alphaMode'] = 'BLEND'
dump['materials'].append({
"baseColorFactor": [0.8, 0.8, 0.8],
"alphaFactor": 1.0,
"metallicFactor": 0.0,
"roughnessFactor": 0.5,
"alphaMode": "OPAQUE",
"alphaCutoff": 0.5,
"baseColorTexture": None,
"alphaTexture": None,
"metallicTexture": None,
"roughnessTexture": None,
}) # append default material
dump['objects'] = [
obj for obj in dump['objects']
if obj['vertices'].size != 0 and obj['faces'].size != 0
]
vertices = torch.from_numpy(np.concatenate([obj['vertices'] for obj in dump['objects']], axis=0)).float()
vertices_min = vertices.min(dim=0)[0]
vertices_max = vertices.max(dim=0)[0]
center = (vertices_min + vertices_max) / 2
scale = 0.99999 / (vertices_max - vertices_min).max()
for obj in dump['objects']:
obj['vertices'] = (torch.from_numpy(obj['vertices']).float() - center) * scale
obj['vertices'] = obj['vertices'].numpy()
obj['mat_ids'][obj['mat_ids'] == -1] = len(dump['materials']) - 1
assert np.all(obj['mat_ids'] >= 0), 'invalid mat_ids'
assert np.all(obj['vertices'] >= -0.5) and np.all(obj['vertices'] <= 0.5), 'vertices out of range'
coord, attr = o_voxel.convert.blender_dump_to_volumetric_attr(dump, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
mip_level_offset=0, verbose=False, timing=False)
del attr['normal']
del attr['emissive']
o_voxel.io.write_vxz(os.path.join(root, f'pbr_voxels_{res}', f'{sha256}.vxz'), coord, attr)
pack[f'pbr_voxelized_{res}'] = True
pack[f'num_pbr_voxels_{res}'] = len(coord)
return pack
except Exception as e:
print(f'Error voxelizing {sha256}: {e}')
return {'sha256': sha256, 'error': str(e)}
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--pbr_dump_root', type=str, default=None,
help='Directory to load mesh dumps')
parser.add_argument('--pbr_voxel_root', type=str, default=None,
help='Directory to save voxelized pbr attributes')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--resolution', type=str, default=1024)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
parser.add_argument('--max_workers', type=int, default=0)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
opt.resolution = sorted([int(x) for x in opt.resolution.split(',')], reverse=True)
opt.pbr_dump_root = opt.pbr_dump_root or opt.root
opt.pbr_voxel_root = opt.pbr_voxel_root or opt.root
for res in opt.resolution:
os.makedirs(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records'), exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.root, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.root, 'metadata.csv')).set_index('sha256')
if os.path.exists(os.path.join(opt.root, 'aesthetic_scores', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.root, 'aesthetic_scores','metadata.csv')).set_index('sha256'))
if os.path.exists(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')):
metadata = metadata.combine_first(pd.read_csv(os.path.join(opt.pbr_dump_root, 'pbr_dumps', 'metadata.csv')).set_index('sha256'))
for res in opt.resolution:
if os.path.exists(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'metadata.csv')):
pbr_voxel_metadata = pd.read_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}','metadata.csv')).set_index('sha256')
pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={'pbr_voxelized': f'pbr_voxelized_{res}', 'num_pbr_voxels': f'num_pbr_voxels_{res}'})
metadata = metadata.combine_first(pbr_voxel_metadata)
metadata = metadata.reset_index()
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['pbr_dumped'] == True]
mask = np.zeros(len(metadata), dtype=bool)
for res in opt.resolution:
if f'pbr_voxelized_{res}' in metadata.columns:
mask |= metadata[f'pbr_voxelized_{res}'] != True
else:
mask[:] = True
break
metadata = metadata[mask]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
func = partial(_pbr_voxelize, pbr_dump_root=opt.pbr_dump_root, root=opt.pbr_voxel_root)
pbr_voxelized = dataset_utils.foreach_instance(metadata, None, func, max_workers=opt.max_workers, no_file=True, desc='Voxelizing')
if 'error' in pbr_voxelized.columns:
errors = pbr_voxelized[pbr_voxelized['error'].notna()]
with open('errors.txt', 'w') as f:
f.write('\n'.join(errors['sha256'].tolist()))
for res in opt.resolution:
if f'pbr_voxelized_{res}' in pbr_voxelized.columns:
pbr_voxel_metadata = pbr_voxelized[pbr_voxelized[f'pbr_voxelized_{res}'] == True]
if len(pbr_voxel_metadata) > 0:
pbr_voxel_metadata = pbr_voxel_metadata[['sha256', f'pbr_voxelized_{res}', f'num_pbr_voxels_{res}']]
pbr_voxel_metadata = pbr_voxel_metadata.rename(columns={f'pbr_voxelized_{res}': 'pbr_voxelized', f'num_pbr_voxels_{res}': 'num_pbr_voxels'})
pbr_voxel_metadata.to_csv(os.path.join(opt.pbr_voxel_root, f'pbr_voxels_{res}', 'new_records', f'part_{opt.rank}.csv'), index=False)
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "garbage_collection_threshold:0.6,max_split_size_mb:128"
os.environ["HSA_XNACK"] = "1"
import cv2
import imageio
from PIL import Image
import torch
# Cap PyTorch to 90% of VRAM. On ROCm, exceeding 100% faults the GPU driver
# and hangs the display rather than raising a Python OOM exception.
# 90% leaves headroom for the display driver and system allocations.
torch.cuda.set_per_process_memory_fraction(0.90)
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.utils import render_utils
from trellis2.renderers import EnvMap
import o_voxel
# 1. Setup Environment Map
envmap = EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
))
# 2. Load Pipeline
pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
pipeline.cuda()
# 3. Load Image & Run
image = Image.open("assets/example_image/T.png")
mesh = pipeline.run(image)[0]
mesh.simplify(16777216) # nvdiffrast limit
# 4. Render Video - Disabled for ROCm systems.
#video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap))
#imageio.mimsave("sample.mp4", video, fps=15)
# 5. Export to GLB
glb = o_voxel.postprocess.to_glb(
vertices = mesh.vertices,
faces = mesh.faces,
attr_volume = mesh.attrs,
coords = mesh.coords,
attr_layout = mesh.layout,
voxel_size = mesh.voxel_size,
aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target = 1000000,
texture_size = 4096,
remesh = True,
remesh_band = 1,
remesh_project = 0,
verbose = True
)
glb.export("sample.glb", extension_webp=True)
\ No newline at end of file
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory
# --- STRIX HALO TRITON CHOKEPOINT (UNCHAINED) ---
import triton.runtime.jit
_original_run = triton.runtime.jit.JITFunction.run
def _amd_safe_triton_run(self, *args, **kwargs):
# 1. Clamp warps to 8 (8 * 64 AMD threads = 512 threads per block)
# This prevents the 2048-thread hardware rejection
if kwargs.get('num_warps', 1) > 8:
kwargs['num_warps'] = 8
# 2. The AMD Zero-Grid Trap Bypass
grid = kwargs.get('grid')
grid_val = grid(kwargs) if callable(grid) else grid
if grid_val and grid_val[0] == 0:
return
# Let everything else flow naturally
return _original_run(self, *args, **kwargs)
triton.runtime.jit.JITFunction.run = _amd_safe_triton_run
# ------------------------------------------------
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline
# 1. Load Pipeline
pipeline = Trellis2TexturingPipeline.from_pretrained("microsoft/TRELLIS.2-4B", config_file="texturing_pipeline.json")
pipeline.cuda()
# 2. Load Mesh, image & Run
mesh = trimesh.load("assets/example_texturing/the_forgotten_knight.ply")
image = Image.open("assets/example_texturing/image.webp")
output = pipeline.run(mesh, image)
# 3. Render Mesh
output.export("textured.glb", extension_webp=True)
\ No newline at end of file
"""
Visualization + render test for TRELLIS.2.
Run this instead of app.py to:
1. Generate a mesh with full pipeline visualizations saved at every stage.
2. Render the resulting mesh with render_utils.render_snapshot — the same
call that app.py uses — and save every render frame to disk.
This is the SMOKING GUN test: if the mesh looks correct in the decode-stage
visualizations but the render images look wrong (15-30% coverage), the bug
is inside the renderer / nvdiffrast path, not in the pipeline.
"""
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import numpy as np
import torch
from PIL import Image
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
# ---------------------------------------------------------------------------
# Config — edit these to taste
# ---------------------------------------------------------------------------
IMAGE_PATH = "assets/example_image/T2.png"
PIPELINE = "1024_cascade" # '512' | '1024' | '1024_cascade' | '1536_cascade'
SEED = 42
NVIEWS = 8 # number of render frames (matches app.py STEPS=8)
RENDER_RES = 1024 # render resolution (matches app.py)
VIZ_DIR = "visualizations_render_test"
# ---------------------------------------------------------------------------
def save_render_frames(images: dict, out_dir: str, prefix: str = "render"):
"""
Save every key × every frame from render_snapshot output to disk.
images is a dict like:
{'shaded': [np.uint8 (H,W,3), ...],
'normal': [...],
'base_color': [...],
'metallic': [...],
'roughness': [...],
'alpha': [...]}
"""
os.makedirs(out_dir, exist_ok=True)
saved = []
for key, frames in images.items():
for i, frame in enumerate(frames):
path = os.path.join(out_dir, f"{prefix}_{key}_{i:03d}.png")
img = Image.fromarray(frame)
img.save(path)
saved.append(path)
return saved
def make_contact_sheet(images: dict, out_dir: str, prefix: str = "contact"):
"""
Build one wide contact sheet per render key and save it.
Useful for seeing all views at once without opening 40 files.
"""
os.makedirs(out_dir, exist_ok=True)
paths = []
for key, frames in images.items():
if not frames:
continue
row = np.concatenate(frames, axis=1) # stack horizontally
path = os.path.join(out_dir, f"{prefix}_{key}_all_views.png")
Image.fromarray(row).save(path)
paths.append(path)
print(f" Contact sheet [{key}]: {row.shape[1]}x{row.shape[0]} -> {path}")
return paths
def export_obj(mesh, path: str):
"""
Export mesh vertices and faces directly to a Wavefront .obj file.
Uses NO nvdiffrast, NO flex_gemm, NO ROCm GPU ops — pure Python/numpy.
Load in Blender to verify 100% geometry completeness.
"""
verts = mesh.vertices.detach().cpu().numpy() # [N, 3]
faces = mesh.faces.detach().cpu().numpy() # [F, 3]
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
with open(path, "w") as f:
f.write("# TRELLIS.2 raw mesh export (no renderer, no GLB pipeline)\n")
f.write(f"# {verts.shape[0]} vertices, {faces.shape[0]} faces\n\n")
for v in verts:
f.write(f"v {v[0]:.6f} {v[1]:.6f} {v[2]:.6f}\n")
f.write("\n")
for tri in faces:
# .obj uses 1-based indices
f.write(f"f {tri[0]+1} {tri[1]+1} {tri[2]+1}\n")
print(f" OBJ export: {verts.shape[0]} vertices, {faces.shape[0]} faces -> {path}")
def main():
print("=" * 70)
print("TRELLIS.2 render smoke-test")
print("=" * 70)
# ------------------------------------------------------------------
# Load pipeline
# ------------------------------------------------------------------
print("\nLoading pipeline...")
pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
pipeline.cuda()
# ------------------------------------------------------------------
# Load HDR environment maps (same as app.py)
# ------------------------------------------------------------------
print("Loading environment maps...")
envmap = {
'forest': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED),
cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'sunset': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED),
cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
'courtyard': EnvMap(torch.tensor(
cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED),
cv2.COLOR_BGR2RGB),
dtype=torch.float32, device='cuda'
)),
}
image = Image.open(IMAGE_PATH)
# ------------------------------------------------------------------
# Run pipeline with ALL stage visualizations enabled
# ------------------------------------------------------------------
print(f"\nRunning pipeline (type={PIPELINE}, seed={SEED}) ...")
print(f"Stage visualizations will be saved to: {VIZ_DIR}/\n")
mesh = pipeline.run(
image,
seed=SEED,
pipeline_type=PIPELINE,
visualize_sparse_structure=False,
visualize_save_dir=VIZ_DIR,
)
print("\nPipeline complete. Mesh object:", type(mesh[0]).__name__)
print(f" vertices : {mesh[0].vertices.shape}")
print(f" faces : {mesh[0].faces.shape}")
if hasattr(mesh[0], 'coords'):
print(f" vox coords: {mesh[0].coords.shape}")
if hasattr(mesh[0], 'attrs'):
print(f" vox attrs : {mesh[0].attrs.shape}")
# ------------------------------------------------------------------
# BYPASS TEST: export raw .obj — zero GPU rendering, zero nvdiffrast
# Load in Blender to verify geometry is 100% complete before blaming
# the renderer or GLB pipeline.
# ------------------------------------------------------------------
obj_path = os.path.join(VIZ_DIR, "raw_mesh.obj")
print(f"\n{'='*70}")
print("Exporting raw .obj (no renderer, no nvdiffrast) ...")
print(f"{'='*70}")
export_obj(mesh[0], obj_path)
print(f" -> Open {obj_path} in Blender to check geometry completeness.")
print(f" If this is complete but renders are 15-30%, the bug is in nvdiffrast/rasterizer.\n")
# ------------------------------------------------------------------
# SMOKING GUN: render with render_snapshot — identical to app.py
# ------------------------------------------------------------------
print(f"\n{'='*70}")
print(f"Rendering {NVIEWS} views at {RENDER_RES}x{RENDER_RES} ...")
print(f"{'='*70}\n")
render_out_dir = os.path.join(VIZ_DIR, "render_frames")
os.makedirs(render_out_dir, exist_ok=True)
images = render_utils.render_snapshot(
mesh[0], # same single-mesh call as app.py
resolution=RENDER_RES,
r=2,
fov=36,
nviews=NVIEWS,
envmap=envmap,
)
print(f"\nRender keys returned: {list(images.keys())}")
for key, frames in images.items():
print(f" {key}: {len(frames)} frames, each {frames[0].shape}")
# Save every individual frame
print(f"\nSaving individual frames to {render_out_dir}/ ...")
saved = save_render_frames(images, render_out_dir, prefix="render")
print(f" Saved {len(saved)} frame files.")
# Save contact sheets (one image per render key, all views side by side)
print(f"\nBuilding contact sheets ...")
make_contact_sheet(images, render_out_dir, prefix="contact")
print(f"\n{'='*70}")
print(f"All done. Check {render_out_dir}/ for render output.")
print(f" contact_shaded_all_views.png <- the key one to look at")
print(f" contact_base_color_all_views.png <- color without lighting")
print(f" contact_normal_all_views.png <- surface normals")
print(f"{'='*70}\n")
if __name__ == "__main__":
main()
# Read Arguments
TEMP=`getopt -o h --long help,basic,flash-attn,cumesh,o-voxel,flexgemm,nvdiffrast,nvdiffrec -n 'setup.sh' -- "$@"`
eval set -- "$TEMP"
HELP=false
BASIC=false
FLASHATTN=false
CUMESH=false
OVOXEL=false
FLEXGEMM=false
NVDIFFRAST=false
NVDIFFREC=false
ERROR=false
if [ "$#" -eq 1 ] ; then
HELP=true
fi
while true ; do
case "$1" in
-h|--help) HELP=true ; shift ;;
--basic) BASIC=true ; shift ;;
--flash-attn) FLASHATTN=true ; shift ;;
--cumesh) CUMESH=true ; shift ;;
--o-voxel) OVOXEL=true ; shift ;;
--flexgemm) FLEXGEMM=true ; shift ;;
--nvdiffrast) NVDIFFRAST=true ; shift ;;
--nvdiffrec) NVDIFFREC=true ; shift ;;
--) shift ; break ;;
*) ERROR=true ; break ;;
esac
done
if [ "$ERROR" = true ] ; then
echo "Error: Invalid argument"
HELP=true
fi
if [ "$HELP" = true ] ; then
echo "Usage: setup.sh [OPTIONS]"
echo "Options:"
echo " -h, --help Display this help message"
echo " --basic Install basic dependencies"
echo " --flash-attn Install flash-attention"
echo " --cumesh Install cumesh"
echo " --o-voxel Install o-voxel"
echo " --flexgemm Install flexgemm"
echo " --nvdiffrast Install nvdiffrast (CUDA) / nvdiffrast-hip (ROCm)"
echo " --nvdiffrec Install nvdiffrec (CUDA only)"
echo ""
echo " Activate your Python environment before running this script."
echo " For ROCm, ensure ROCm PyTorch is installed first:"
echo " pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/rocm6.2.4"
echo " For CUDA:"
echo " pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124"
return
fi
# Get system information
WORKDIR=$(pwd)
if command -v nvidia-smi > /dev/null; then
PLATFORM="cuda"
elif command -v rocminfo > /dev/null; then
PLATFORM="hip"
else
echo "Error: No supported GPU found"
exit 1
fi
if [ "$BASIC" = true ] ; then
pip install imageio imageio-ffmpeg tqdm easydict opencv-python-headless ninja trimesh "transformers==4.56.0" gradio==6.0.1 tensorboard pandas lpips zstandard pyfqmr matplotlib
pip install git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
sudo apt install -y libjpeg-dev
pip install pillow-simd
pip install kornia timm
fi
if [ "$FLASHATTN" = true ] ; then
if [ "$PLATFORM" = "cuda" ] ; then
pip install flash-attn==2.7.3
elif [ "$PLATFORM" = "hip" ] ; then
echo "[FLASHATTN] Prebuilt binaries not found. Building from source..."
mkdir -p /tmp/extensions
git clone --recursive https://github.com/ROCm/flash-attention.git /tmp/extensions/flash-attention
cd /tmp/extensions/flash-attention
git checkout tags/v2.7.3-cktile
GPU_ARCHS=gfx1201 python setup.py install
cd $WORKDIR
else
echo "[FLASHATTN] Unsupported platform: $PLATFORM"
fi
fi
if [ "$NVDIFFRAST" = true ] ; then
if [ "$PLATFORM" = "cuda" ] ; then
mkdir -p /tmp/extensions
git clone -b v0.4.0 https://github.com/NVlabs/nvdiffrast.git /tmp/extensions/nvdiffrast
pip install /tmp/extensions/nvdiffrast --no-build-isolation
elif [ "$PLATFORM" = "hip" ] ; then
mkdir -p /tmp/extensions
git clone https://github.com/Cardboard-box-a/nvdiffrast-hip.git /tmp/extensions/nvdiffrast-hip
pip install /tmp/extensions/nvdiffrast-hip --no-build-isolation
fi
fi
if [ "$NVDIFFREC" = true ] ; then
mkdir -p /tmp/extensions
git clone -b renderutils https://github.com/Cardboard-box-a/nvdiffrec.git /tmp/extensions/nvdiffrec
pip install /tmp/extensions/nvdiffrec --no-build-isolation
fi
if [ "$CUMESH" = true ] ; then
mkdir -p /tmp/extensions
git clone https://github.com/Cardboard-box-a/CuMesh.git /tmp/extensions/CuMesh --recursive
pip install /tmp/extensions/CuMesh --no-build-isolation
fi
if [ "$FLEXGEMM" = true ] ; then
mkdir -p /tmp/extensions
git clone https://github.com/Cardboard-box-a/FlexGEMM-rocm.git /tmp/extensions/FlexGEMM
pip install /tmp/extensions/FlexGEMM --no-build-isolation
fi
if [ "$OVOXEL" = true ] ; then
mkdir -p /tmp/extensions
git clone https://github.com/Cardboard-box-a/o-voxel.git /tmp/extensions/o-voxel --recursive
pip install /tmp/extensions/o-voxel --no-build-isolation
fi
import os
import sys
import json
import glob
import argparse
from easydict import EasyDict as edict
import torch
import torch.multiprocessing as mp
import numpy as np
import random
from trellis2 import models, datasets, trainers
from trellis2.utils.dist_utils import setup_dist
def find_ckpt(cfg):
# Load checkpoint
cfg['load_ckpt'] = None
if cfg.load_dir != '':
if cfg.ckpt == 'latest':
files = glob.glob(os.path.join(cfg.load_dir, 'ckpts', 'misc_*.pt'))
if len(files) != 0:
cfg.load_ckpt = max([
int(os.path.basename(f).split('step')[-1].split('.')[0])
for f in files
])
elif cfg.ckpt == 'none':
cfg.load_ckpt = None
else:
cfg.load_ckpt = int(cfg.ckpt)
return cfg
def setup_rng(rank):
torch.manual_seed(rank)
torch.cuda.manual_seed_all(rank)
np.random.seed(rank)
random.seed(rank)
def get_model_summary(model):
model_summary = 'Parameters:\n'
model_summary += '=' * 128 + '\n'
model_summary += f'{"Name":<{72}}{"Shape":<{32}}{"Type":<{16}}{"Grad"}\n'
num_params = 0
num_trainable_params = 0
for name, param in model.named_parameters():
model_summary += f'{name:<{72}}{str(param.shape):<{32}}{str(param.dtype):<{16}}{param.requires_grad}\n'
num_params += param.numel()
if param.requires_grad:
num_trainable_params += param.numel()
model_summary += '\n'
model_summary += f'Number of parameters: {num_params}\n'
model_summary += f'Number of trainable parameters: {num_trainable_params}\n'
return model_summary
def main(local_rank, cfg):
# Set up distributed training
rank = cfg.node_rank * cfg.num_gpus + local_rank
world_size = cfg.num_nodes * cfg.num_gpus
if world_size > 1:
setup_dist(rank, local_rank, world_size, cfg.master_addr, cfg.master_port)
# Seed rngs
setup_rng(rank)
# Load data
dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)
# Build model
model_dict = {
name: getattr(models, model.name)(**model.args).cuda()
for name, model in cfg.models.items()
}
# Model summary
if rank == 0:
for name, backbone in model_dict.items():
model_summary = get_model_summary(backbone)
print(f'\n\nBackbone: {name}\n' + model_summary)
with open(os.path.join(cfg.output_dir, f'{name}_model_summary.txt'), 'w') as fp:
print(model_summary, file=fp)
# Build trainer
trainer = getattr(trainers, cfg.trainer.name)(model_dict, dataset, **cfg.trainer.args, output_dir=cfg.output_dir, load_dir=cfg.load_dir, step=cfg.load_ckpt)
# Train
if not cfg.tryrun:
if cfg.profile:
trainer.profile()
else:
trainer.run()
if __name__ == '__main__':
# Arguments and config
parser = argparse.ArgumentParser()
## config
parser.add_argument('--config', type=str, required=True, help='Experiment config file')
## io and resume
parser.add_argument('--output_dir', type=str, required=True, help='Output directory')
parser.add_argument('--load_dir', type=str, default='', help='Load directory, default to output_dir')
parser.add_argument('--ckpt', type=str, default='latest', help='Checkpoint step to resume training, default to latest')
parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
parser.add_argument('--auto_retry', type=int, default=3, help='Number of retries on error')
## dubug
parser.add_argument('--tryrun', action='store_true', help='Try run without training')
parser.add_argument('--profile', action='store_true', help='Profile training')
## multi-node and multi-gpu
parser.add_argument('--num_nodes', type=int, default=1, help='Number of nodes')
parser.add_argument('--node_rank', type=int, default=0, help='Node rank')
parser.add_argument('--num_gpus', type=int, default=-1, help='Number of GPUs per node, default to all')
parser.add_argument('--master_addr', type=str, default='localhost', help='Master address for distributed training')
parser.add_argument('--master_port', type=str, default='12345', help='Port for distributed training')
opt = parser.parse_args()
opt.load_dir = opt.load_dir if opt.load_dir != '' else opt.output_dir
opt.num_gpus = torch.cuda.device_count() if opt.num_gpus == -1 else opt.num_gpus
## Load config
config = json.load(open(opt.config, 'r'))
## Combine arguments and config
cfg = edict()
cfg.update(opt.__dict__)
cfg.update(config)
print('\n\nConfig:')
print('=' * 80)
print(json.dumps(cfg.__dict__, indent=4))
# Prepare output directory
if cfg.node_rank == 0:
os.makedirs(cfg.output_dir, exist_ok=True)
## Save command and config
with open(os.path.join(cfg.output_dir, 'command.txt'), 'w') as fp:
print(' '.join(['python'] + sys.argv), file=fp)
with open(os.path.join(cfg.output_dir, 'config.json'), 'w') as fp:
json.dump(config, fp, indent=4)
# Run
if cfg.auto_retry == 0:
cfg = find_ckpt(cfg)
if cfg.num_gpus > 1:
mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
else:
main(0, cfg)
else:
for rty in range(cfg.auto_retry):
try:
cfg = find_ckpt(cfg)
if cfg.num_gpus > 1:
mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
else:
main(0, cfg)
break
except Exception as e:
print(f'Error: {e}')
print(f'Retrying ({rty + 1}/{cfg.auto_retry})...')
\ No newline at end of file
from . import models
from . import modules
from . import pipelines
from . import renderers
from . import representations
from . import utils
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment