#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

""""
This file is the entry point for launching experiments with Implicitron.

Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
    and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
    pass, visualizations and metric printing

Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:

```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```

To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
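
For example, to resume a previous run on GPU 1 (illustrative values; the keys
are defined in `ExperimentConfig` below):

```
./experiment.py --config-name base_config.yaml gpu_idx=1 resume=True exp_dir=./data/my_experiment/
```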

Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
  - Checkpoints:
        Model, optimizer and stats are stored in the directory
        named by the `exp_dir` key from the config file / CLI parameters.
  - Stats:
        Stats are logged and plotted to the file "train_stats.pdf" in the
        same directory. The stats are also saved as part of the checkpoint file.
  - Visualizations:
        Predictions are plotted to a visdom server running at the
        port specified by the `visdom_server` and `visdom_port` keys in the
        config file.

"""

import copy
import json
import logging
import os
import random
import time
import warnings
from dataclasses import field
from typing import Any, Dict, Optional, Tuple

import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    Configurable,
    enable_get_default_args,
    expand_args_fields,
    get_default_args_field,
    remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase


logger = logging.getLogger(__name__)

if version.parse(hydra.__version__) < version.Version("1.1"):
    raise ValueError(
        f"Hydra version {hydra.__version__} is too old."
        " (Implicitron requires version 1.1 or later.)"
    )

try:
    # only makes sense on the FAIR cluster
    import pytorch3d.implicitron.fair_cluster.slurm  # noqa: F401
except ModuleNotFoundError:
    pass


def init_model(
    cfg: DictConfig,
    force_load: bool = False,
    clear_stats: bool = False,
    load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
    """
    Returns an instance of `GenericModel`.

    If `cfg.resume` is set or `force_load` is true, attempts to load the last
    checkpoint from `cfg.exp_dir`. If no checkpoint can be loaded, the model is
    returned with initial weights, unless `force_load` is true, in which case a
    FileNotFoundError is raised.

    Args:
        force_load: If true, force load model from checkpoint even if
            cfg.resume is false.
        clear_stats: If true, clear the stats object loaded from checkpoint
        load_model_only: If true, load only the model weights from checkpoint
            and do not load the state of the optimizer and stats.

    Returns:
        model: The model with optionally loaded weights from checkpoint
        stats: The stats structure (optionally loaded from checkpoint)
        optimizer_state: The optimizer state dict containing
            `state` and `param_groups` keys (optionally loaded from checkpoint)

    Raises:
        FileNotFoundError if `force_load` is passed but checkpoint is not found.
    """

    # Initialize the model
    if cfg.architecture == "generic":
        model = GenericModel(**cfg.generic_model_args)
    else:
        raise ValueError(f"No such arch {cfg.architecture}.")

    # Determine the network outputs that should be logged
    if hasattr(model, "log_vars"):
        log_vars = copy.deepcopy(list(model.log_vars))
    else:
        log_vars = ["objective"]

    visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"

    # Init the stats struct
    stats = Stats(
        log_vars,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Retrieve the last checkpoint
    if cfg.resume_epoch > 0:
        model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
    else:
        model_path = model_io.find_last_checkpoint(cfg.exp_dir)

    optimizer_state = None
    if model_path is not None:
        logger.info("found previous model %s" % model_path)
        if force_load or cfg.resume:
            logger.info("   -> resuming")
            if load_model_only:
                model_state_dict = torch.load(model_io.get_model_path(model_path))
                stats_load, optimizer_state = None, None
            else:
                model_state_dict, stats_load, optimizer_state = model_io.load_model(
                    model_path
                )

                # Determine if stats should be reset
                if not clear_stats:
                    if stats_load is None:
                        logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                        last_epoch = model_io.parse_epoch_from_model_path(model_path)
                        logger.info(f"Estimated resume epoch = {last_epoch}")

                        # Reset the stats struct
                        for _ in range(last_epoch + 1):
                            stats.new_epoch()
                        assert last_epoch == stats.epoch
                    else:
                        stats = stats_load

                    # Update stats properties in case they were reset on load
                    stats.visdom_env = visdom_env_charts
                    stats.visdom_server = cfg.visdom_server
                    stats.visdom_port = cfg.visdom_port
                    stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
                    stats.synchronize_logged_vars(log_vars)
                else:
                    logger.info("   -> clearing stats")

            try:
                # TODO: fix on creation of the buffers
                # after the hack above, this will not pass in most cases
                # ... but this is fine for now
                model.load_state_dict(model_state_dict, strict=True)
            except RuntimeError as e:
                logger.error(e)
                logger.info("Cant load state dict in strict mode! -> trying non-strict")
                model.load_state_dict(model_state_dict, strict=False)
            model.log_vars = log_vars
        else:
            logger.info("   -> but not resuming -> starting from scratch")
    elif force_load:
        raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")

    return model, stats, optimizer_state
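
# Example usage (a sketch, assuming `cfg` is a fully expanded ExperimentConfig
# and a checkpoint exists in `cfg.exp_dir`):
#
#     model, stats, optimizer_state = init_model(cfg, force_load=True)
#     model.to("cuda:0")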


def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, float] = (0.9, 0.999),
    milestones: tuple = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved
        breed: The type of optimizer to use, e.g. adam
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only "multistep"
            is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient and its square
            in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning rate is
            modified
        max_epochs: The maximum number of epochs to run the optimizer for

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raises:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info("  -> solver type = %s" % breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info("  -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When loading from checkpoint, this will make sure that the
    # lr is correctly set even after returning
    for _ in range(last_epoch):
        scheduler.step()

    # Record the maximum number of epochs on the returned scheduler
    scheduler.max_epochs = max_epochs

    optimizer.zero_grad()
    return optimizer, scheduler


enable_get_default_args(init_optimizer)
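
# Example (illustrative): rebuilding the optimizer and scheduler when resuming
# at epoch 10, so that the learning rate matches where the previous run stopped.
#
#     optimizer, scheduler = init_optimizer(
#         model,
#         optimizer_state=loaded_optimizer_state,  # e.g. from model_io.load_model
#         last_epoch=10,
#         milestones=(200, 300),
#     )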


def trainvalidate(
    model,
    stats,
    epoch,
    loader,
    optimizer,
    validation: bool,
    bp_var: str = "objective",
    metric_print_interval: int = 5,
    visualize_interval: int = 100,
    visdom_env_root: str = "trainvalidate",
    clip_grad: float = 0.0,
    device: str = "cuda:0",
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        model: The model module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        epoch: The index of the current epoch
        loader: The dataloader to use for the loop
        optimizer: The optimizer module optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted
        visdom_env_root: The name of the visdom environment to use for plotting
        clip_grad: Optionally clip the gradient norms.
            If set to a value <= 0.0, no clipping is applied.
        device: The device on which to run the model.

    Returns:
        None
    """

    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, batch in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = batch.to(device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we don't overwrite anything
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if it % metric_print_interval == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if visualize_interval > 0 and it % visualize_interval == 0:
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"

            model.visualize(
                viz,
                visdom_env_imgs,
                preds,
                prefix,
            )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            loss.backward()
            if clip_grad > 0.0:
                # Optionally clip the gradient norms.
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), clip_grad
                )
                if total_norm > clip_grad:
                    logger.info(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {clip_grad / total_norm}."
                    )

            optimizer.step()
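
# Sketch (hypothetical objects): a single standalone validation pass with the
# loop above; `model`, `stats`, `val_loader` and `optimizer` must be built by
# the caller, e.g. via init_model and init_optimizer.
#
#     trainvalidate(
#         model, stats, epoch=0, loader=val_loader, optimizer=optimizer,
#         validation=True, device="cuda:0",
#     )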


def run_training(cfg: DictConfig, device: str = "cpu") -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.
    """

    # set the debug mode
    if cfg.detect_anomaly:
        logger.info("Anomaly detection!")
    torch.autograd.set_detect_anomaly(cfg.detect_anomaly)

    # create the output folder
    os.makedirs(cfg.exp_dir, exist_ok=True)
    _seed_all_random_engines(cfg.seed)
    remove_unused_components(cfg)

    # dump the exp config to the exp dir
    try:
        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
        OmegaConf.save(config=cfg, f=cfg_filename)
    except PermissionError:
        warnings.warn("Cant dump config due to insufficient permissions!")

    # setup datasets
    datasource = ImplicitronDataSource(**cfg.data_source_args)
    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
    task = datasource.get_task()

    # init the model
    model, stats, optimizer_state = init_model(cfg)
    start_epoch = stats.epoch + 1

    # move model to gpu
    model.to(device)

    # only run evaluation on the test dataloader
    if cfg.eval_only:
        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
        return

    # init the optimizer
    optimizer, scheduler = init_optimizer(
        model,
        optimizer_state=optimizer_state,
        last_epoch=start_epoch,
        **cfg.solver_args,
    )

    # check the scheduler and stats have been initialized correctly
    assert scheduler.last_epoch == stats.epoch + 1
    assert scheduler.last_epoch == start_epoch

    past_scheduler_lrs = []
    # loop through epochs
    for epoch in range(start_epoch, cfg.solver_args.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            _seed_all_random_engines(cfg.seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.info(f"scheduler lr = {cur_lr:1.2e}")
            past_scheduler_lrs.append(cur_lr)

            # train loop
            trainvalidate(
                model,
                stats,
                epoch,
                dataloaders.train,
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                **cfg,
            )

            # val loop (optional)
            if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
                trainvalidate(
                    model,
                    stats,
                    epoch,
                    dataloaders.val,
                    optimizer,
                    True,
                    visdom_env_root=vis_utils.get_visdom_env(cfg),
                    device=device,
                    **cfg,
                )

            # eval loop (optional)
            if (
                dataloaders.test is not None
                and cfg.test_interval > 0
                and epoch % cfg.test_interval == 0
            ):
                # _run_eval expects source cameras, not stats; mirror _eval_and_dump
                all_source_cameras = (
                    _get_all_source_cameras(datasets.train)
                    if task == Task.SINGLE_SEQUENCE and datasets.train is not None
                    else None
                )
                _run_eval(
                    model, all_source_cameras, dataloaders.test, task, device=device
                )

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            # save model
            if cfg.store_checkpoints:
                if cfg.store_checkpoints_purge > 0:
                    for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                        model_io.purge_epoch(cfg.exp_dir, prev_epoch)
                outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
                model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)

            scheduler.step()

            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)


def _eval_and_dump(
    cfg,
    task: Task,
    datasets: DatasetMap,
    dataloaders: DataLoaderMap,
    model,
    stats,
    device,
) -> None:
    """
    Run the evaluation loop with the test data loader and
    save the predictions to the `exp_dir`.
    """

    dataloader = dataloaders.test

    if dataloader is None:
        raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')

    if task == Task.SINGLE_SEQUENCE:
        if datasets.train is None:
            raise ValueError("train dataset must be provided")
        all_source_cameras = _get_all_source_cameras(datasets.train)
    else:
        all_source_cameras = None
    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)

    # add the evaluation epoch to the results
    for r in results:
        r["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
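    # `is_known` has shape (batch, 1, 1, 1); multiplying by it zeroes out every
    # pixel of the unknown (target) frames while leaving known frames intact.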
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value_masked = getattr(frame_data_for_eval, k).clone() * is_known
        setattr(frame_data_for_eval, k, value_masked)
    return frame_data_for_eval


def _run_eval(model, all_source_cameras, loader, task: Task, device):
    """
    Run the evaluation loop on the test dataloader
    """
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.to(device)

    model.eval()

    per_batch_eval_results = []
    logger.info("Evaluating model ...")
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(device)

        # mask out the unknown images so that the model does not see them
        frame_data_for_eval = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            preds = model(
                **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
            )
            implicitron_render = copy.deepcopy(preds["implicitron_render"])
            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    implicitron_render,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_source_cameras,
                )
            )

    _, category_result = evaluate.summarize_nvs_eval_results(
        per_batch_eval_results, task
    )

    return category_result["results"]


def _get_all_source_cameras(
    dataset: JsonIndexDataset,
    num_workers: int = 8,
) -> CamerasBase:
    """
    Load and return all the source cameras in the training dataset
    """

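    # Load the entire dataset as one batch so that all frames (and hence all
    # cameras) are collated together.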
    all_frame_data = next(
        iter(
            torch.utils.data.DataLoader(
                dataset,
                shuffle=False,
                batch_size=len(dataset),
                num_workers=num_workers,
                collate_fn=FrameData.collate,
            )
        )
    )

    is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
    source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
    return source_cameras


def _seed_all_random_engines(seed: int):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


class ExperimentConfig(Configurable):
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
    architecture: str = "generic"
    detect_anomaly: bool = False
    eval_only: bool = False
    exp_dir: str = "./data/default_experiment/"
    exp_idx: int = 0
    gpu_idx: int = 0
    metric_print_interval: int = 5
    resume: bool = True
    resume_epoch: int = -1
    seed: int = 0
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1
    visdom_env: str = ""
    visdom_port: int = 8097
    visdom_server: str = "http://127.0.0.1"
    visualize_interval: int = 1000
    clip_grad: float = 0.0

    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},  # Make hydra not change the working dir.
            "output_subdir": None,  # disable storing the .hydra logs
        }
    )
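
# A minimal user config (an illustrative sketch) placed in ./configs/ could
# override these defaults, e.g. my_config.yaml:
#
#     defaults:
#     - default_config
#     - _self_
#
#     exp_dir: ./data/my_experiment/
#     solver_args:
#       lr: 0.0001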


expand_args_fields(ExperimentConfig)

cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)


@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
    # Set the device
    device = "cpu"
    if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
        device = f"cuda:{cfg.gpu_idx}"
    logger.info(f"Running experiment on device: {device}")
    run_training(cfg, device)


if __name__ == "__main__":
    experiment()