Commit 37bd280d authored by Darijan Gudelj's avatar Darijan Gudelj Committed by Facebook GitHub Bot
Browse files

load whole dataset in train loop

Summary: Loads the whole dataset and moves it to the device and sends it to for sampling to enable full dataset heterogeneous raysampling.

Reviewed By: bottler

Differential Revision: D39263009

fbshipit-source-id: c527537dfc5f50116849656c9e171e868f6845b1
parent c311a4cb
...@@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13 ...@@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
train_loader=train_loader, train_loader=train_loader,
val_loader=val_loader, val_loader=val_loader,
test_loader=test_loader, test_loader=test_loader,
# pyre-ignore[6]
train_dataset=datasets.train, train_dataset=datasets.train,
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
......
...@@ -22,7 +22,7 @@ from pytorch3d.implicitron.tools.config import ( ...@@ -22,7 +22,7 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader from torch.utils.data import DataLoader, Dataset
from .utils import seed_all_random_engines from .utils import seed_all_random_engines
...@@ -44,6 +44,7 @@ class TrainingLoopBase(ReplaceableBase): ...@@ -44,6 +44,7 @@ class TrainingLoopBase(ReplaceableBase):
train_loader: DataLoader, train_loader: DataLoader,
val_loader: Optional[DataLoader], val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader], test_loader: Optional[DataLoader],
train_dataset: Dataset,
model: ImplicitronModelBase, model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
scheduler: Any, scheduler: Any,
...@@ -116,6 +117,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase): ...@@ -116,6 +117,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
train_loader: DataLoader, train_loader: DataLoader,
val_loader: Optional[DataLoader], val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader], test_loader: Optional[DataLoader],
train_dataset: Dataset,
model: ImplicitronModelBase, model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
scheduler: Any, scheduler: Any,
......
...@@ -389,7 +389,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ...@@ -389,7 +389,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
) )
# (1) Sample rendering rays with the ray sampler. # (1) Sample rendering rays with the ray sampler.
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29] # pyre-ignore[29]
ray_bundle: ImplicitronRayBundle = self.raysampler(
target_cameras, target_cameras,
evaluation_mode, evaluation_mode,
mask=mask_crop[:n_targets] mask=mask_crop[:n_targets]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment