Commit 745aaf39 authored by Patrick Labatut's avatar Patrick Labatut Committed by Facebook GitHub Bot
Browse files

No side effect with invalid inputs to save_obj / save_ply

Summary: Do not create output files with invalid inputs to `save_{obj,ply}()`.

Reviewed By: bottler

Differential Revision: D20720282

fbshipit-source-id: 3b611a10da6f6eecacab2a1900bf16f89e2000aa
parent 83feed56
...@@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): ...@@ -526,6 +526,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
new_f = False new_f = False
if isinstance(f, str): if isinstance(f, str):
new_f = True new_f = True
...@@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): ...@@ -541,21 +549,14 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
# TODO (nikhilar) Speed up this function. # TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None): def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
if not (len(verts) or len(faces)): if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided") warnings.warn("Empty 'verts' and 'faces' arguments provided")
return return
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
raise ValueError(
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
verts, faces = verts.cpu(), faces.cpu() verts, faces = verts.cpu(), faces.cpu()
lines = "" lines = ""
......
...@@ -700,7 +700,7 @@ def load_ply(f): ...@@ -700,7 +700,7 @@ def load_ply(f):
return verts, faces return verts, faces
def _save_ply(f, verts, faces, decimal_places: Optional[int]): def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
""" """
Internal implementation for saving a mesh to a .ply file. Internal implementation for saving a mesh to a .ply file.
...@@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]): ...@@ -710,15 +710,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]):
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
raise ValueError( assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
print("ply\nformat ascii 1.0", file=f) print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f) print(f"element vertex {verts.shape[0]}", file=f)
...@@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None): ...@@ -760,6 +753,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
""" """
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
new_f = False new_f = False
if isinstance(f, str): if isinstance(f, str):
new_f = True new_f = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment