#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This file is the entry point for launching experiments with Implicitron.

Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
    and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
    pass, visualizations and metric printing

Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:

```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```

To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
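
For example, to pin the run to the first visible GPU and write outputs to a
chosen directory (a sketch using the `gpu_idx` and `exp_dir` keys described
in this docstring):

```
./experiment.py --config-name base_config.yaml gpu_idx=0 exp_dir=./my_experiment
```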

Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
  - Checkpoints:
        Model, optimizer and stats are stored in the directory
        named by the `exp_dir` key from the config file / CLI parameters.
  - Stats:
        Stats are logged and plotted to the file "train_stats.pdf" in the
        same directory. The stats are also saved as part of the checkpoint file.
  - Visualizations:
        Predictions are plotted to a visdom server running at the
        port specified by the `visdom_server` and `visdom_port` keys in the
        config file.

"""
import copy
import json
import logging
import os
import random
import time
import warnings
from typing import Any, Dict, Optional, Tuple

import hydra
import lpips
import numpy as np
import torch
import tqdm
from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase

from .impl.experiment_config import ExperimentConfig
from .impl.optimization import init_optimizer

logger = logging.getLogger(__name__)

if version.parse(hydra.__version__) < version.Version("1.1"):
    raise ValueError(
        f"Hydra version {hydra.__version__} is too old."
        " (Implicitron requires version 1.1 or later.)"
    )

try:
    # only makes sense in FAIR cluster
    import pytorch3d.implicitron.fair_cluster.slurm  # noqa: F401
except ModuleNotFoundError:
    pass

no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
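# For example, a single-process run that bypasses the accelerate wrapper can be
# launched by setting this flag to any value (a sketch, reusing the launch
# command from the module docstring):
#
#   PYTORCH3D_NO_ACCELERATE=1 ./experiment.py --config-name base_config.yaml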


def init_model(
    *,
    cfg: DictConfig,
    accelerator: Optional[Accelerator] = None,
    force_load: bool = False,
    clear_stats: bool = False,
    load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
    """
    Returns an instance of `GenericModel`.

    If `cfg.resume` is set or `force_load` is true, attempts to load the last
    checkpoint from `cfg.exp_dir`. If no checkpoint is found, the model is
    returned with its initial weights, unless `force_load` is true, in which
    case a FileNotFoundError is raised.

    Args:
        force_load: If true, force load model from checkpoint even if
            cfg.resume is false.
        clear_stats: If true, clear the stats object loaded from checkpoint
        load_model_only: If true, load only the model weights from checkpoint
            and do not load the state of the optimizer and stats.

    Returns:
        model: The model with optionally loaded weights from checkpoint
        stats: The stats structure (optionally loaded from checkpoint)
        optimizer_state: The optimizer state dict containing
            `state` and `param_groups` keys (optionally loaded from checkpoint)

    Raises:
        FileNotFoundError if `force_load` is passed but checkpoint is not found.
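
    Example:
        A minimal sketch of how `run_training` uses this function to resume
        from the latest checkpoint in `cfg.exp_dir` (here with no accelerator;
        `cfg` is the loaded experiment config):

            model, stats, optimizer_state = init_model(cfg=cfg, accelerator=None)
            start_epoch = stats.epoch + 1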
    """

    # Initialize the model
    if cfg.architecture == "generic":
        model = GenericModel(**cfg.generic_model_args)
    else:
        raise ValueError(f"No such arch {cfg.architecture}.")

    # Determine the network outputs that should be logged
    if hasattr(model, "log_vars"):
        log_vars = copy.deepcopy(list(model.log_vars))
    else:
        log_vars = ["objective"]

    visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"

    # Init the stats struct
    stats = Stats(
        log_vars,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Retrieve the last checkpoint
    if cfg.resume_epoch > 0:
        model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
    else:
        model_path = model_io.find_last_checkpoint(cfg.exp_dir)

    optimizer_state = None
    if model_path is not None:
        logger.info("found previous model %s" % model_path)
        if force_load or cfg.resume:
            logger.info("   -> resuming")

            map_location = None
            if accelerator is not None and not accelerator.is_local_main_process:
                map_location = {
                    "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                }
            if load_model_only:
                model_state_dict = torch.load(
                    model_io.get_model_path(model_path), map_location=map_location
                )
                stats_load, optimizer_state = None, None
            else:
                model_state_dict, stats_load, optimizer_state = model_io.load_model(
                    model_path, map_location=map_location
                )

                # Determine if stats should be reset
                if not clear_stats:
                    if stats_load is None:
                        logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                        last_epoch = model_io.parse_epoch_from_model_path(model_path)
                        logger.info(f"Estimated resume epoch = {last_epoch}")

                        # Reset the stats struct
                        for _ in range(last_epoch + 1):
                            stats.new_epoch()
                        assert last_epoch == stats.epoch
                    else:
                        stats = stats_load

                    # Update stats properties in case it was reset on load
                    stats.visdom_env = visdom_env_charts
                    stats.visdom_server = cfg.visdom_server
                    stats.visdom_port = cfg.visdom_port
                    stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
                    stats.synchronize_logged_vars(log_vars)
                else:
                    logger.info("   -> clearing stats")

            try:
                # TODO: fix on creation of the buffers
                # after the hack above, this will not pass in most cases
                # ... but this is fine for now
                model.load_state_dict(model_state_dict, strict=True)
            except RuntimeError as e:
                logger.error(e)
                logger.info("Can't load state dict in strict mode! -> trying non-strict")
                model.load_state_dict(model_state_dict, strict=False)
            model.log_vars = log_vars
        else:
            logger.info("   -> but not resuming -> starting from scratch")
    elif force_load:
        raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")

    return model, stats, optimizer_state


def trainvalidate(
    model,
    stats,
    epoch,
    loader,
    optimizer,
    validation: bool,
    *,
    accelerator: Optional[Accelerator],
    device: torch.device,
    bp_var: str = "objective",
    metric_print_interval: int = 5,
    visualize_interval: int = 100,
    visdom_env_root: str = "trainvalidate",
    clip_grad: float = 0.0,
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        model: The model module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        epoch: The index of the current epoch
        loader: The dataloader to use for the loop
        optimizer: The optimizer module optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted
        visdom_env_root: The name of the visdom environment to use for plotting
        clip_grad: Optionally clip the gradient norms.
            If set to a value <= 0.0, no clipping is performed.
        device: The device on which to run the model.

    Returns:
        None
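
    Example:
        A minimal sketch of one training epoch, mirroring the call made in
        `run_training` (all arguments are assumed to be initialized there):

            trainvalidate(
                model,
                stats,
                epoch,
                train_loader,
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                accelerator=accelerator,
                **cfg,
            )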
    """

    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, net_input in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = net_input.to(device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we don't overwrite anything
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if it % metric_print_interval == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if (
            (accelerator is None or accelerator.is_local_main_process)
            and visualize_interval > 0
            and it % visualize_interval == 0
        ):
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"

            model.visualize(
                viz,
                visdom_env_imgs,
                preds,
                prefix,
            )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            if accelerator is None:
                loss.backward()
            else:
                accelerator.backward(loss)
            if clip_grad > 0.0:
                # Optionally clip the gradient norms.
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), clip_grad
                )
                if total_norm > clip_grad:
                    logger.info(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {clip_grad / float(total_norm)}."
                    )

            optimizer.step()


def run_training(cfg: DictConfig) -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.
    """

    # Initialize the accelerator
    if no_accelerate:
        accelerator = None
        device = torch.device("cuda:0")
    else:
        accelerator = Accelerator(device_placement=False)
        logger.info(accelerator.state)
        device = accelerator.device

    logger.info(f"Running experiment on device: {device}")

    # set the debug mode
    if cfg.detect_anomaly:
        logger.info("Anomaly detection!")
    torch.autograd.set_detect_anomaly(cfg.detect_anomaly)

    # create the output folder
    os.makedirs(cfg.exp_dir, exist_ok=True)
    _seed_all_random_engines(cfg.seed)
    remove_unused_components(cfg)

    # dump the exp config to the exp dir
    try:
        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
        OmegaConf.save(config=cfg, f=cfg_filename)
    except PermissionError:
        warnings.warn("Can't dump config due to insufficient permissions!")

    # setup datasets
    datasource = ImplicitronDataSource(**cfg.data_source_args)
    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
    task = datasource.get_task()

    # init the model
    model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
    start_epoch = stats.epoch + 1

    # move model to gpu
    model.to(device)

    # only run evaluation on the test dataloader
    if cfg.eval_only:
        _eval_and_dump(
            cfg,
            task,
            datasource.all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
        )
        return

    # init the optimizer
    optimizer, scheduler = init_optimizer(
        model,
        optimizer_state=optimizer_state,
        last_epoch=start_epoch,
        **cfg.solver_args,
    )

    # check the scheduler and stats have been initialized correctly
    assert scheduler.last_epoch == stats.epoch + 1
    assert scheduler.last_epoch == start_epoch

    # Wrap all modules in the distributed library
    # Note: we don't pass the scheduler to prepare as it
    # doesn't need to be stepped at each optimizer step
    train_loader = dataloaders.train
    val_loader = dataloaders.val
    if accelerator is not None:
        (
            model,
            optimizer,
            train_loader,
            val_loader,
        ) = accelerator.prepare(model, optimizer, train_loader, val_loader)

    past_scheduler_lrs = []
    # loop through epochs
    for epoch in range(start_epoch, cfg.solver_args.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            _seed_all_random_engines(cfg.seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.info(f"scheduler lr = {cur_lr:1.2e}")
            past_scheduler_lrs.append(cur_lr)

            # train loop
            trainvalidate(
                model,
                stats,
                epoch,
                train_loader,
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                accelerator=accelerator,
                **cfg,
            )

            # val loop (optional)
            if val_loader is not None and epoch % cfg.validation_interval == 0:
                trainvalidate(
                    model,
                    stats,
                    epoch,
                    val_loader,
                    optimizer,
                    True,
                    visdom_env_root=vis_utils.get_visdom_env(cfg),
                    device=device,
                    accelerator=accelerator,
                    **cfg,
                )

            # eval loop (optional)
            if (
                dataloaders.test is not None
                and cfg.test_interval > 0
                and epoch % cfg.test_interval == 0
            ):
                _run_eval(
                    model,
                    datasource.all_train_cameras,
                    dataloaders.test,
                    task,
                    camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
                    device=device,
                )

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            # save model only on the main process
            if cfg.store_checkpoints and (
                accelerator is None or accelerator.is_local_main_process
            ):
                if cfg.store_checkpoints_purge > 0:
                    for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                        model_io.purge_epoch(cfg.exp_dir, prev_epoch)
                outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
                unwrapped_model = (
                    model if accelerator is None else accelerator.unwrap_model(model)
                )
                model_io.safe_save_model(
                    unwrapped_model, stats, outfile, optimizer=optimizer
                )

            scheduler.step()

            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
        _eval_and_dump(
            cfg,
            task,
            datasource.all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
        )


def _eval_and_dump(
    cfg,
    task: Task,
    all_train_cameras: Optional[CamerasBase],
    datasets: DatasetMap,
    dataloaders: DataLoaderMap,
    model,
    stats,
    device,
) -> None:
    """
    Run the evaluation loop with the test data loader and
    save the predictions to the `exp_dir`.
    """

    dataloader = dataloaders.test

    if dataloader is None:
        raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')

    results = _run_eval(
        model,
        all_train_cameras,
        dataloader,
        task,
        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
        device=device,
    )

    # add the evaluation epoch to the results
    for r in results:
        r["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value_masked = getattr(frame_data_for_eval, k).clone() * is_known
        setattr(frame_data_for_eval, k, value_masked)
    return frame_data_for_eval


def _run_eval(
    model,
    all_train_cameras,
    loader,
    task: Task,
    camera_difficulty_bin_breaks: Tuple[float, float],
    device,
):
    """
    Run the evaluation loop on the test dataloader
    """
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.to(device)

    model.eval()

    per_batch_eval_results = []
    logger.info("Evaluating model ...")
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(device)

        # mask out the unknown images so that the model does not see them
        frame_data_for_eval = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            preds = model(
                **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
            )

            # TODO: Cannot use accelerate gather for two reasons:
            # (1) TypeError: Can't apply _gpu_gather_one on object of type
            # <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
            # only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
            # (2) Same error above but for frame_data which contains Cameras.

            implicitron_render = copy.deepcopy(preds["implicitron_render"])

            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    implicitron_render,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_train_cameras,
                )
            )

    _, category_result = evaluate.summarize_nvs_eval_results(
        per_batch_eval_results, task, camera_difficulty_bin_breaks
    )

    return category_result["results"]


def _seed_all_random_engines(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def _setup_envvars_for_cluster() -> bool:
    """
    Prepares to run on cluster if relevant.
    Returns whether the FAIR cluster is in use.
    """
    # TODO: How much of this is needed in general?

    try:
        import submitit
    except ImportError:
        return False

    try:
        # Only needed when launching on cluster with slurm and submitit
        job_env = submitit.JobEnvironment()
    except RuntimeError:
        return False

    os.environ["LOCAL_RANK"] = str(job_env.local_rank)
    os.environ["RANK"] = str(job_env.global_rank)
    os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "42918"
    logger.info(
        "Num tasks %s, global_rank %s"
        % (str(job_env.num_tasks), str(job_env.global_rank))
    )

    return True


expand_args_fields(ExperimentConfig)
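# Register the expanded experiment schema with Hydra's ConfigStore under the
# name "default_config", which the `@hydra.main` decorator below uses as its
# primary config.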
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)


@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
    # CUDA_VISIBLE_DEVICES must have been set.

    if "CUDA_DEVICE_ORDER" not in os.environ:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    if not _setup_envvars_for_cluster():
        logger.info("Running locally")

    # TODO: The following may be needed for hydra/submitit to work
    expand_args_fields(GenericModel)
    expand_args_fields(AdaptiveRaySampler)
    expand_args_fields(MultiPassEmissionAbsorptionRenderer)
    expand_args_fields(ImplicitronDataSource)

    run_training(cfg)


if __name__ == "__main__":
    experiment()