#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

""""
This file is the entry point for launching experiments with Implicitron.

Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
    and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
    pass, visualizations and metric printing

Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:

```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```

To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
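
For example, to resume a run on GPU 1 writing into a given output directory
(illustrative values only; `gpu_idx`, `exp_dir` and `resume` are keys of the
experiment config):

```
./experiment.py --config-name base_config.yaml gpu_idx=1 exp_dir=./exp resume=True
```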

Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
  - Checkpoints:
        Model, optimizer and stats are stored in the directory
        named by the `exp_dir` key from the config file / CLI parameters.
  - Stats
        Stats are logged and plotted to the file "train_stats.pdf" in the
        same directory. The stats are also saved as part of the checkpoint file.
  - Visualizations
        Predictions are plotted to a visdom server running at the
        port specified by the `visdom_server` and `visdom_port` keys in the
        config file.

"""
import copy
import json
import logging
import os
import random
import time
import warnings
from typing import Any, Dict, Optional, Tuple

import hydra
import lpips
import numpy as np
import torch
import tqdm
from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase

from .impl.experiment_config import ExperimentConfig
from .impl.optimization import init_optimizer

logger = logging.getLogger(__name__)

if version.parse(hydra.__version__) < version.Version("1.1"):
    raise ValueError(
        f"Hydra version {hydra.__version__} is too old."
        " (Implicitron requires version 1.1 or later.)"
    )

try:
    # only makes sense in FAIR cluster
    import pytorch3d.implicitron.fair_cluster.slurm  # noqa: F401
except ModuleNotFoundError:
    pass


def init_model(
    cfg: DictConfig,
    force_load: bool = False,
    clear_stats: bool = False,
    load_model_only: bool = False,
    accelerator: Optional[Accelerator] = None,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
    """
    Returns an instance of `GenericModel`.

    If `cfg.resume` is set or `force_load` is true,
    attempts to load the last checkpoint from `cfg.exp_dir`. If no checkpoint
    is found, the model is returned with its initial weights, unless
    `force_load` is passed, in which case a FileNotFoundError is raised.

    Args:
        force_load: If true, force load model from checkpoint even if
            cfg.resume is false.
        clear_stats: If true, clear the stats object loaded from checkpoint
        load_model_only: If true, load only the model weights from checkpoint
            and do not load the state of the optimizer and stats.
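        accelerator: An optional Accelerator; if running distributed, it is
            used to map checkpoint tensors to the local device on load.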

    Returns:
        model: The model with optionally loaded weights from checkpoint
        stats: The stats structure (optionally loaded from checkpoint)
        optimizer_state: The optimizer state dict containing
            `state` and `param_groups` keys (optionally loaded from checkpoint)

    Raises:
        FileNotFoundError if `force_load` is passed but checkpoint is not found.
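
    Example (a minimal sketch; assumes `cfg` is a fully populated
    experiment config):
        model, stats, optimizer_state = init_model(cfg)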
    """

    # Initialize the model
    if cfg.architecture == "generic":
        model = GenericModel(**cfg.generic_model_args)
    else:
        raise ValueError(f"No such arch {cfg.architecture}.")

    # Determine the network outputs that should be logged
    if hasattr(model, "log_vars"):
        log_vars = copy.deepcopy(list(model.log_vars))
    else:
        log_vars = ["objective"]

    visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"

    # Init the stats struct
    stats = Stats(
        log_vars,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Retrieve the last checkpoint
    if cfg.resume_epoch > 0:
        model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
    else:
        model_path = model_io.find_last_checkpoint(cfg.exp_dir)

    optimizer_state = None
    if model_path is not None:
        logger.info("found previous model %s" % model_path)
        if force_load or cfg.resume:
            logger.info("   -> resuming")

            map_location = None
            if accelerator is not None and not accelerator.is_local_main_process:
                # remap tensors saved from cuda:0 to this process's device
                map_location = {
                    "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                }
            if load_model_only:
                model_state_dict = torch.load(
                    model_io.get_model_path(model_path), map_location=map_location
                )
                stats_load, optimizer_state = None, None
            else:
                model_state_dict, stats_load, optimizer_state = model_io.load_model(
                    model_path, map_location=map_location
                )

                # Determine if stats should be reset
                if not clear_stats:
                    if stats_load is None:
                        logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                        last_epoch = model_io.parse_epoch_from_model_path(model_path)
                        logger.info(f"Estimated resume epoch = {last_epoch}")

                        # Reset the stats struct
                        for _ in range(last_epoch + 1):
                            stats.new_epoch()
                        assert last_epoch == stats.epoch
                    else:
                        stats = stats_load

                    # Update stats properties in case it was reset on load
                    stats.visdom_env = visdom_env_charts
                    stats.visdom_server = cfg.visdom_server
                    stats.visdom_port = cfg.visdom_port
                    stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
                    stats.synchronize_logged_vars(log_vars)
                else:
                    logger.info("   -> clearing stats")

            try:
                # TODO: fix on creation of the buffers
                # after the hack above, this will not pass in most cases
                # ... but this is fine for now
                model.load_state_dict(model_state_dict, strict=True)
            except RuntimeError as e:
                logger.error(e)
                logger.info("Cant load state dict in strict mode! -> trying non-strict")
                model.load_state_dict(model_state_dict, strict=False)
            model.log_vars = log_vars
        else:
            logger.info("   -> but not resuming -> starting from scratch")
    elif force_load:
        raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")

    return model, stats, optimizer_state


def trainvalidate(
    model,
    stats,
    epoch,
    loader,
    optimizer,
    validation: bool,
    bp_var: str = "objective",
    metric_print_interval: int = 5,
    visualize_interval: int = 100,
    visdom_env_root: str = "trainvalidate",
    clip_grad: float = 0.0,
    device: str = "cuda:0",
    accelerator: Optional[Accelerator] = None,
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        model: The model module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        epoch: The index of the current epoch
        loader: The dataloader to use for the loop
        optimizer: The optimizer module optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted
        visdom_env_root: The name of the visdom environment to use for plotting
        clip_grad: Optionally clip the gradient norms.
            If set to a value <= 0.0, no clipping is performed.
        device: The device on which to run the model.
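        accelerator: An optional Accelerator; when given, it runs the
            backward pass and provides the device for the input batch.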

    Returns:
        None
    """

    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, net_input in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = net_input.to(accelerator.device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we don't overwrite something
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if it % metric_print_interval == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if (
            accelerator.is_local_main_process
            and visualize_interval > 0
            and it % visualize_interval == 0
        ):
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"

            model.visualize(
                viz,
                visdom_env_imgs,
                preds,
                prefix,
            )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            accelerator.backward(loss)
            if clip_grad > 0.0:
                # Optionally clip the gradient norms.
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), clip_grad
                )
                if total_norm > clip_grad:
                    logger.info(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {clip_grad / float(total_norm)}."
                    )

            optimizer.step()


def run_training(cfg: DictConfig) -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.
    """

    # Initialize the accelerator
    accelerator = Accelerator(device_placement=False)
    logger.info(accelerator.state)

    device = accelerator.device
    logger.info(f"Running experiment on device: {device}")

    if accelerator.is_local_main_process:
        logger.info(OmegaConf.to_yaml(cfg))

    # set the debug mode
    if cfg.detect_anomaly:
        logger.info("Anomaly detection!")
    torch.autograd.set_detect_anomaly(cfg.detect_anomaly)

    # create the output folder
    os.makedirs(cfg.exp_dir, exist_ok=True)
    _seed_all_random_engines(cfg.seed)
    remove_unused_components(cfg)

    # dump the exp config to the exp dir
    try:
        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
        OmegaConf.save(config=cfg, f=cfg_filename)
    except PermissionError:
        warnings.warn("Cant dump config due to insufficient permissions!")

    # setup datasets
    datasource = ImplicitronDataSource(**cfg.data_source_args)
    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
    task = datasource.get_task()
    all_train_cameras = datasource.get_all_train_cameras()

    # init the model
    model, stats, optimizer_state = init_model(cfg, accelerator=accelerator)
    start_epoch = stats.epoch + 1

    # move model to gpu
    model.to(accelerator.device)

    # only run evaluation on the test dataloader
    if cfg.eval_only:
        _eval_and_dump(
            cfg,
            task,
            all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
            accelerator=accelerator,
        )
        return

    # init the optimizer
    optimizer, scheduler = init_optimizer(
        model,
        optimizer_state=optimizer_state,
        last_epoch=start_epoch,
        **cfg.solver_args,
    )

    # check the scheduler and stats have been initialized correctly
    assert scheduler.last_epoch == stats.epoch + 1
    assert scheduler.last_epoch == start_epoch

    # Wrap all modules in the distributed library
    # Note: we don't pass the scheduler to prepare as it
    # doesn't need to be stepped at each optimizer step
    (
        model,
        optimizer,
        train_loader,
        val_loader,
    ) = accelerator.prepare(model, optimizer, dataloaders.train, dataloaders.val)

    past_scheduler_lrs = []
    # loop through epochs
    for epoch in range(start_epoch, cfg.solver_args.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            _seed_all_random_engines(cfg.seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.info(f"scheduler lr = {cur_lr:1.2e}")
            past_scheduler_lrs.append(cur_lr)

            # train loop
            trainvalidate(
                model,
                stats,
                epoch,
                train_loader,
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                accelerator=accelerator,
                **cfg,
            )

            # val loop (optional)
            if val_loader is not None and epoch % cfg.validation_interval == 0:
                trainvalidate(
                    model,
                    stats,
                    epoch,
                    val_loader,
                    optimizer,
                    True,
                    visdom_env_root=vis_utils.get_visdom_env(cfg),
                    device=device,
                    accelerator=accelerator,
                    **cfg,
                )

            # eval loop (optional)
            if (
                dataloaders.test is not None
                and cfg.test_interval > 0
                and epoch % cfg.test_interval == 0
            ):
                _run_eval(
                    model,
                    all_train_cameras,
                    dataloaders.test,
                    task,
                    camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
                    device=device,
                    accelerator=accelerator,
                )

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            # save model only on the main process
            if cfg.store_checkpoints and accelerator.is_local_main_process:
                if cfg.store_checkpoints_purge > 0:
                    for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                        model_io.purge_epoch(cfg.exp_dir, prev_epoch)
                outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
                unwrapped_model = accelerator.unwrap_model(model)
                model_io.safe_save_model(
                    unwrapped_model, stats, outfile, optimizer=optimizer
                )

            scheduler.step()

            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
        _eval_and_dump(
            cfg,
            task,
            all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
            accelerator=accelerator,
        )


def _eval_and_dump(
    cfg,
    task: Task,
    all_train_cameras: Optional[CamerasBase],
    datasets: DatasetMap,
    dataloaders: DataLoaderMap,
    model,
    stats,
    device,
    accelerator: Optional[Accelerator] = None,
) -> None:
    """
    Run the evaluation loop with the test data loader and
    save the predictions to the `exp_dir`.
    """

    dataloader = dataloaders.test

    if dataloader is None:
        raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')

    results = _run_eval(
        model,
        all_train_cameras,
        dataloader,
        task,
        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
        device=device,
        accelerator=accelerator,
    )

    # add the evaluation epoch to the results
    for r in results:
        r["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
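    # is_known is 1.0 for known (source) frames and 0.0 for unknown (target)
    # frames, so the multiplication below zeroes-out the target images,
    # depth maps and masks.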
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value_masked = getattr(frame_data_for_eval, k).clone() * is_known
        setattr(frame_data_for_eval, k, value_masked)
    return frame_data_for_eval


def _run_eval(
    model,
    all_train_cameras,
    loader,
    task: Task,
    camera_difficulty_bin_breaks: Tuple[float, float],
    device,
    accelerator: Optional[Accelerator] = None,
):
    """
    Run the evaluation loop on the test dataloader
    """
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.to(accelerator.device)

    model.eval()

    per_batch_eval_results = []
    logger.info("Evaluating model ...")
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(accelerator.device)

        # mask out the unknown images so that the model does not see them
        frame_data_for_eval = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            preds = model(
                **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
            )

            # TODO: Cannot use accelerate gather for two reasons:
            # (1) TypeError: Can't apply _gpu_gather_one on object of type
            # <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
            # only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
            # (2) Same error above but for frame_data which contains Cameras.

            implicitron_render = copy.deepcopy(preds["implicitron_render"])

            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    implicitron_render,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_train_cameras,
                )
            )

    _, category_result = evaluate.summarize_nvs_eval_results(
        per_batch_eval_results, task, camera_difficulty_bin_breaks
    )

    return category_result["results"]


def _seed_all_random_engines(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def _setup_envvars_for_cluster() -> bool:
    """
    Prepares to run on cluster if relevant.
    Returns whether FAIR cluster in use.
    """
    # TODO: How much of this is needed in general?

    try:
        import submitit
    except ImportError:
        return False

    try:
        # Only needed when launching on cluster with slurm and submitit
        job_env = submitit.JobEnvironment()
    except RuntimeError:
        return False

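    # Expose the slurm job geometry through the environment variables
    # which torch.distributed (and hence accelerate) expects.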
    os.environ["LOCAL_RANK"] = str(job_env.local_rank)
    os.environ["RANK"] = str(job_env.global_rank)
    os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "42918"
    logger.info(
        "Num tasks %s, global_rank %s"
        % (str(job_env.num_tasks), str(job_env.global_rank))
    )

    return True


expand_args_fields(ExperimentConfig)
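# Register the structured config with hydra's ConfigStore so that the yaml
# files in ./configs can compose with and override `default_config`.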
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)


@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
    # CUDA_VISIBLE_DEVICES must have been set.

    if "CUDA_DEVICE_ORDER" not in os.environ:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    if not _setup_envvars_for_cluster():
        logger.info("Running locally")

    # TODO: The following may be needed for hydra/submitit to work
    expand_args_fields(GenericModel)
    expand_args_fields(AdaptiveRaySampler)
    expand_args_fields(MultiPassEmissionAbsorptionRenderer)
    expand_args_fields(ImplicitronDataSource)

    run_training(cfg)


if __name__ == "__main__":
    experiment()