Unverified Commit 47b33464 authored by YiYi Xu, committed by GitHub

Shap-E: add support for mesh output (#4062)



* add output_type=mesh

* update img2img

* make style

* add doc

* make style

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add docstring for output_type

* add a section in doc about hub mesh visualization/ rotation

* update conversion script so default background is white

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* renderer -> shap_e_renderer

* img2img renderer -> shap_e_renderer

* fix tests

---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 07f1fbb1
@@ -128,6 +128,63 @@ gif_path = export_to_gif(images[0], "burger_3d.gif")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif)
### Generate mesh
For both [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`], you can generate a mesh output by passing `output_type="mesh"` to the pipeline, and then use the [`ShapEPipeline.export_to_ply`] utility function to save the output as a `ply` file. We also provide a [`ShapEPipeline.export_to_obj`] function that you can use to save mesh outputs as `obj` files (a short sketch follows the example below).
```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_ply
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo = "openai/shap-e"
pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to(device)
guidance_scale = 15.0
prompt = "A birthday cupcake"
images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images
ply_path = export_to_ply(images[0], "3d_cake.ply")
print(f"saved to folder: {ply_path}")
```
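You can save the mesh as an `obj` file in the same way. Below is a minimal sketch that reuses the `images` output from the example above; the filename `3d_cake.obj` is just an illustration:
```python
from diffusers.utils import export_to_obj

# images[0] is the mesh output returned when output_type="mesh"
obj_path = export_to_obj(images[0], "3d_cake.obj")
print(f"Saved to: {obj_path}")
```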
The Hugging Face Hub Dataset viewer supports visualizing mesh files in the `glb` format. Below we show how to convert your mesh file into `glb` format so that the Dataset viewer can render your 3D object.
First, install the `trimesh` library:
```
pip install trimesh
```
To convert the mesh file into `glb` format:
```python
import trimesh
mesh = trimesh.load("3d_cake.ply")
mesh.export("3d_cake.glb", file_type="glb")
```
By default, the mesh output of Shap-E is rendered from the bottom viewpoint; you can change the default viewpoint by applying a rotation transformation:
```python
import trimesh
import numpy as np
mesh = trimesh.load("3d_cake.ply")
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
mesh = mesh.apply_transform(rot)
mesh.export("3d_cake.glb", file_type="glb")
```
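If you want to inspect the mesh locally before uploading, trimesh's built-in viewer is handy. This is a quick sketch and assumes the optional `pyglet` dependency is installed for the windowed viewer:
```python
import trimesh

# glb files typically load as a trimesh.Scene
scene = trimesh.load("3d_cake.glb")
scene.show()  # opens an interactive viewer window (requires `pip install pyglet`)
```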
Now you can upload your mesh file to your dataset and visualize it! Here is the 3D cake we just generated:
https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/shap_e/3d_cake.glb
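You can upload the file from the Hub web UI, or programmatically with `huggingface_hub`. Here is a minimal sketch, assuming you are logged in (`huggingface-cli login`) and have write access to a dataset repo; `your-username/3d-assets` is a hypothetical repo id:
```python
from huggingface_hub import HfApi

api = HfApi()
# upload the glb file into a dataset repo so the Dataset viewer can render it
api.upload_file(
    path_or_fileobj="3d_cake.glb",
    path_in_repo="3d_cake.glb",
    repo_id="your-username/3d-assets",  # hypothetical dataset repo
    repo_type="dataset",
)
```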
## ShapEPipeline
[[autodoc]] ShapEPipeline
    - all
......
This diff is collapsed.
@@ -95,7 +95,7 @@ class ShapEPipeline(DiffusionPipeline):
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
-        renderer ([`ShapERenderer`]):
+        shap_e_renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
            with the NeRF rendering method
    """
@@ -106,7 +106,7 @@ class ShapEPipeline(DiffusionPipeline):
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: HeunDiscreteScheduler,
-        renderer: ShapERenderer,
+        shap_e_renderer: ShapERenderer,
    ):
        super().__init__()
@@ -115,7 +115,7 @@ class ShapEPipeline(DiffusionPipeline):
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
-            renderer=renderer,
+            shap_e_renderer=shap_e_renderer,
        )
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
@@ -149,7 +149,7 @@ class ShapEPipeline(DiffusionPipeline):
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
        hook = None
-        for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
+        for cpu_offloaded_model in [self.text_encoder, self.prior, self.shap_e_renderer]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
        if self.safety_checker is not None:
@@ -218,7 +218,7 @@ class ShapEPipeline(DiffusionPipeline):
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
-        output_type: Optional[str] = "pil",  # pil, np, latent
+        output_type: Optional[str] = "pil",  # pil, np, latent, mesh
        return_dict: bool = True,
    ):
        """
@@ -248,8 +248,8 @@ class ShapEPipeline(DiffusionPipeline):
            frame_size (`int`, *optional*, default to 64):
                the width and height of each image frame of the generated 3d output
            output_type (`str`, *optional*, defaults to `"pt"`):
-                The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
-                (`torch.Tensor`).
+                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+                (`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
@@ -319,30 +319,39 @@ class ShapEPipeline(DiffusionPipeline):
                sample=latents,
            ).prev_sample
+        if output_type not in ["np", "pil", "latent", "mesh"]:
+            raise ValueError(
+                f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
+            )

        if output_type == "latent":
            return ShapEPipelineOutput(images=latents)

        images = []
-        for i, latent in enumerate(latents):
-            image = self.renderer.decode(
-                latent[None, :],
-                device,
-                size=frame_size,
-                ray_batch_size=4096,
-                n_coarse_samples=64,
-                n_fine_samples=128,
-            )
-            images.append(image)
-
-        images = torch.stack(images)
-
-        if output_type not in ["np", "pil"]:
-            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
-
-        images = images.cpu().numpy()
-
-        if output_type == "pil":
-            images = [self.numpy_to_pil(image) for image in images]
+        if output_type == "mesh":
+            for i, latent in enumerate(latents):
+                mesh = self.shap_e_renderer.decode_to_mesh(
+                    latent[None, :],
+                    device,
+                )
+                images.append(mesh)
+
+        else:
+            # np, pil
+            for i, latent in enumerate(latents):
+                image = self.shap_e_renderer.decode_to_image(
+                    latent[None, :],
+                    device,
+                    size=frame_size,
+                )
+                images.append(image)
+
+            images = torch.stack(images)
+
+            images = images.cpu().numpy()
+
+            if output_type == "pil":
+                images = [self.numpy_to_pil(image) for image in images]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -94,7 +94,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
-        renderer ([`ShapERenderer`]):
+        shap_e_renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
            with the NeRF rendering method
    """
@@ -105,7 +105,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
        image_encoder: CLIPVisionModel,
        image_processor: CLIPImageProcessor,
        scheduler: HeunDiscreteScheduler,
-        renderer: ShapERenderer,
+        shap_e_renderer: ShapERenderer,
    ):
        super().__init__()
@@ -114,7 +114,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
            image_encoder=image_encoder,
            image_processor=image_processor,
            scheduler=scheduler,
-            renderer=renderer,
+            shap_e_renderer=shap_e_renderer,
        )
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
@@ -170,7 +170,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
-        output_type: Optional[str] = "pil",  # pil, np, latent
+        output_type: Optional[str] = "pil",  # pil, np, latent, mesh
        return_dict: bool = True,
    ):
        """
@@ -200,8 +200,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
            frame_size (`int`, *optional*, default to 64):
                the width and height of each image frame of the generated 3d output
            output_type (`str`, *optional*, defaults to `"pt"`):
-                The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
-                (`torch.Tensor`).
+                (`np.array`),`"latent"` (`torch.Tensor`), mesh ([`MeshDecoderOutput`]).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
@@ -275,32 +274,39 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
                sample=latents,
            ).prev_sample
+        if output_type not in ["np", "pil", "latent", "mesh"]:
+            raise ValueError(
+                f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
+            )

        if output_type == "latent":
            return ShapEPipelineOutput(images=latents)

        images = []
-        for i, latent in enumerate(latents):
-            print()
-            image = self.renderer.decode(
-                latent[None, :],
-                device,
-                size=frame_size,
-                ray_batch_size=4096,
-                n_coarse_samples=64,
-                n_fine_samples=128,
-            )
-            images.append(image)
-
-        images = torch.stack(images)
-
-        if output_type not in ["np", "pil"]:
-            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")
-
-        images = images.cpu().numpy()
-
-        if output_type == "pil":
-            images = [self.numpy_to_pil(image) for image in images]
+        if output_type == "mesh":
+            for i, latent in enumerate(latents):
+                mesh = self.shap_e_renderer.decode_to_mesh(
+                    latent[None, :],
+                    device,
+                )
+                images.append(mesh)
+
+        else:
+            # np, pil
+            for i, latent in enumerate(latents):
+                image = self.shap_e_renderer.decode_to_image(
+                    latent[None, :],
+                    device,
+                    size=frame_size,
+                )
+                images.append(image)
+
+            images = torch.stack(images)
+
+            images = images.cpu().numpy()
+
+            if output_type == "pil":
+                images = [self.numpy_to_pil(image) for image in images]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
......
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
import numpy as np
import torch
@@ -116,6 +116,101 @@ def integrate_samples(volume_range, ts, density, channels):
    return channels, weights, transmittance
def volume_query_points(volume, grid_size):
indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
zs = indices % grid_size
ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
combined = torch.stack([xs, ys, zs], dim=1)
return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min
def _convert_srgb_to_linear(u: torch.Tensor):
return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)
def _create_flat_edge_indices(
flat_cube_indices: torch.Tensor,
grid_size: Tuple[int, int, int],
):
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
y_offset = num_xs
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
z_offset = num_xs + num_ys
return torch.stack(
[
# Edges spanning x-axis.
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2],
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
+ (flat_cube_indices[:, 1] + 1) * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1,
# Edges spanning y-axis.
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
),
(
y_offset
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
(
y_offset
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
+ flat_cube_indices[:, 1] * grid_size[2]
+ flat_cube_indices[:, 2]
+ 1
),
# Edges spanning z-axis.
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 1] * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
(
z_offset
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
+ flat_cube_indices[:, 2]
),
],
dim=-1,
)
class VoidNeRFModel(nn.Module):
    """
    Implements the default empty space model where all queries are rendered as background.
@@ -368,6 +463,141 @@ class ImportanceRaySampler(nn.Module):
        return ts
@dataclass
class MeshDecoderOutput(BaseOutput):
"""
A 3D triangle mesh with optional data at the vertices and faces.
Args:
verts (`torch.Tensor` of shape `(N, 3)`):
array of vertext coordinates
faces (`torch.Tensor` of shape `(N, 3)`):
array of triangles, pointing to indices in verts.
vertext_channels (Dict):
vertext coordinates for each color channel
"""
verts: torch.Tensor
faces: torch.Tensor
vertex_channels: Dict[str, torch.Tensor]
class MeshDecoder(nn.Module):
"""
Construct meshes from Signed distance functions (SDFs) using marching cubes method
"""
def __init__(self):
super().__init__()
cases = torch.zeros(256, 5, 3, dtype=torch.long)
masks = torch.zeros(256, 5, dtype=torch.bool)
self.register_buffer("cases", cases)
self.register_buffer("masks", masks)
def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor):
"""
For a signed distance field, produce a mesh using marching cubes.
:param field: a 3D tensor of field values, where negative values correspond
to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively.
:param min_point: a tensor of shape [3] containing the point corresponding
to (0, 0, 0) in the field.
:param size: a tensor of shape [3] containing the per-axis distance from the
(0, 0, 0) field corner and the (-1, -1, -1) field corner.
"""
assert len(field.shape) == 3, "input must be a 3D scalar field"
dev = field.device
cases = self.cases.to(dev)
masks = self.masks.to(dev)
min_point = min_point.to(dev)
size = size.to(dev)
grid_size = field.shape
grid_size_tensor = torch.tensor(grid_size).to(size)
# Create bitmasks between 0 and 255 (inclusive) indicating the state
# of the eight corners of each cube.
bitmasks = (field > 0).to(torch.uint8)
bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)
# Compute corner coordinates across the entire grid.
corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
:, None, None
]
corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
:, None
]
corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)
# Compute all vertices across all edges in the grid, even though we will
# throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
# These are all midpoints, and don't account for interpolation (which is
# done later based on the used edge midpoints).
edge_midpoints = torch.cat(
[
((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
],
dim=0,
)
# Create a flat array of [X, Y, Z] indices for each cube.
cube_indices = torch.zeros(
grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
)
cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
flat_cube_indices = cube_indices.reshape(-1, 3)
# Create a flat array mapping each cube to 12 global edge indices.
edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)
# Apply the LUT to figure out the triangles.
flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask
local_tris = cases[flat_bitmasks]
local_masks = masks[flat_bitmasks]
# Compute the global edge indices for the triangles.
global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
local_tris.shape
)
# Select the used triangles for each cube.
selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]
# Now we have a bunch of indices into the full list of possible vertices,
# but we want to reduce this list to only the used vertices.
used_vertex_indices = torch.unique(selected_tris.view(-1))
used_edge_midpoints = edge_midpoints[used_vertex_indices]
old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
old_index_to_new_index[used_vertex_indices] = torch.arange(
len(used_vertex_indices), device=dev, dtype=torch.long
)
# Rewrite the triangles to use the new indices
faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)
# Compute the actual interpolated coordinates corresponding to edge midpoints.
v1 = torch.floor(used_edge_midpoints).to(torch.long)
v2 = torch.ceil(used_edge_midpoints).to(torch.long)
s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
# The signs of s1 and s2 should be different. We want to find
# t such that t*s2 + (1-t)*s1 = 0.
t = (s1 / (s1 - s2))[:, None]
verts = t * p2 + (1 - t) * p1
return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)
@dataclass
class MLPNeRFModelOutput(BaseOutput):
    density: torch.Tensor
@@ -429,7 +659,7 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
        return mapped_output
-    def forward(self, *, position, direction, ts, nerf_level="coarse"):
+    def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
        h = encode_position(position)
        h_preact = h
@@ -455,10 +685,17 @@ class MLPNeRSTFModel(ModelMixin, ConfigMixin):
        if nerf_level == "coarse":
            h_density = activation["density_coarse"]
-            h_channels = activation["nerf_coarse"]
        else:
            h_density = activation["density_fine"]
-            h_channels = activation["nerf_fine"]
+        if rendering_mode == "nerf":
+            if nerf_level == "coarse":
+                h_channels = activation["nerf_coarse"]
+            else:
+                h_channels = activation["nerf_fine"]
+        elif rendering_mode == "stf":
+            h_channels = activation["stf"]
        density = self.density_activation(h_density)
        signed_distance = self.sdf_activation(activation["sdf"])
@@ -583,6 +820,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
+        self.mesh_decoder = MeshDecoder()
    @torch.no_grad()
    def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
@@ -664,7 +902,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
        return channels, weighted_sampler, model_out
    @torch.no_grad()
-    def decode(
+    def decode_to_image(
        self,
        latents,
        device,
@@ -707,3 +945,106 @@ class ShapERenderer(ModelMixin, ConfigMixin):
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)
        return images
@torch.no_grad()
def decode_to_mesh(
self,
latents,
device,
grid_size: int = 128,
query_batch_size: int = 4096,
texture_channels: Tuple = ("R", "G", "B"),
):
# 1. project the the paramters from the generated latents
projected_params = self.params_proj(latents)
# 2. update the mlp layers of the renderer
for name, param in self.mlp.state_dict().items():
if f"nerstf.{name}" in projected_params.keys():
param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))
# 3. decoding with STF rendering
# 3.1 query the SDF values at vertices along a regular 128**3 grid
query_points = volume_query_points(self.volume, grid_size)
query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype)
fields = []
for idx in range(0, query_positions.shape[1], query_batch_size):
query_batch = query_positions[:, idx : idx + query_batch_size]
model_out = self.mlp(
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
)
fields.append(model_out.signed_distance)
# predicted SDF values
fields = torch.cat(fields, dim=1)
fields = fields.float()
assert (
len(fields.shape) == 3 and fields.shape[-1] == 1
), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
fields = fields.reshape(1, *([grid_size] * 3))
# create grid 128 x 128 x 128
# - force a negative border around the SDFs to close off all the models.
full_grid = torch.zeros(
1,
grid_size + 2,
grid_size + 2,
grid_size + 2,
device=fields.device,
dtype=fields.dtype,
)
full_grid.fill_(-1.0)
full_grid[:, 1:-1, 1:-1, 1:-1] = fields
fields = full_grid
# apply a differentiable implementation of Marching Cubes to construct meshs
raw_meshes = []
mesh_mask = []
for field in fields:
raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
mesh_mask.append(True)
raw_meshes.append(raw_mesh)
mesh_mask = torch.tensor(mesh_mask, device=fields.device)
max_vertices = max(len(m.verts) for m in raw_meshes)
# 3.2. query the texture color head at each vertex of the resulting mesh.
texture_query_positions = torch.stack(
[m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
dim=0,
)
texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)
textures = []
for idx in range(0, texture_query_positions.shape[1], query_batch_size):
query_batch = texture_query_positions[:, idx : idx + query_batch_size]
texture_model_out = self.mlp(
position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
)
textures.append(texture_model_out.channels)
# predict texture color
textures = torch.cat(textures, dim=1)
textures = _convert_srgb_to_linear(textures)
textures = textures.float()
# 3.3 augument the mesh with texture data
assert len(textures.shape) == 3 and textures.shape[-1] == len(
texture_channels
), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))
return raw_meshes[0]
@@ -103,7 +103,7 @@ if is_torch_available():
    )
    from .torch_utils import maybe_allow_in_graph
-from .testing_utils import export_to_gif, export_to_video
+from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
logger = get_logger(__name__)
......
import inspect
+import io
import logging
import multiprocessing
import os
import random
import re
+import struct
import tempfile
import unittest
import urllib.parse
+from contextlib import contextmanager
from distutils.util import strtobool
from io import BytesIO, StringIO
from pathlib import Path
@@ -315,6 +318,85 @@ def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) ->
    return output_gif_path
@contextmanager
def buffered_writer(raw_f):
f = io.BufferedWriter(raw_f)
yield f
f.flush()
def export_to_ply(mesh, output_ply_path: str = None):
"""
Write a PLY file for a mesh.
"""
if output_ply_path is None:
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
coords = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
with buffered_writer(open(output_ply_path, "wb")) as f:
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B")
for item in vertices:
f.write(format.pack(*item))
else:
format = struct.Struct("<3f")
for vertex in coords.tolist():
f.write(format.pack(*vertex))
if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
return output_ply_path
def export_to_obj(mesh, output_obj_path: str = None):
if output_obj_path is None:
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
verts = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
vertices = [
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
]
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
combined_data = ["v " + vertex for vertex in vertices] + faces
with open(output_obj_path, "w") as f:
f.writelines("\n".join(combined_data))
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
    if is_opencv_available():
        import cv2
......
@@ -131,7 +131,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        prior = self.dummy_prior
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
-        renderer = self.dummy_renderer
+        shap_e_renderer = self.dummy_renderer
        scheduler = HeunDiscreteScheduler(
            beta_schedule="exp",
@@ -145,7 +145,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "prior": prior,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
-            "renderer": renderer,
+            "shap_e_renderer": shap_e_renderer,
            "scheduler": scheduler,
        }
......
@@ -143,7 +143,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        prior = self.dummy_prior
        image_encoder = self.dummy_image_encoder
        image_processor = self.dummy_image_processor
-        renderer = self.dummy_renderer
+        shap_e_renderer = self.dummy_renderer
        scheduler = HeunDiscreteScheduler(
            beta_schedule="exp",
@@ -157,7 +157,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "prior": prior,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
-            "renderer": renderer,
+            "shap_e_renderer": shap_e_renderer,
            "scheduler": scheduler,
        }
......