experiment.py 23.5 KB
Newer Older
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This file is the entry point for launching experiments with Implicitron.

Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
    and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
    pass, visualizations and metric printing

Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:

```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```

To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.

Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
  - Checkpoints:
        Model, optimizer and stats are stored in the directory
        named by the `exp_dir` key from the config file / CLI parameters.
  - Stats:
        Stats are logged and plotted to the file "train_stats.pdf" in the
        same directory. The stats are also saved as part of the checkpoint file.
  - Visualizations:
        Predictions are plotted to a visdom server at the address and port
        given by the `visdom_server` and `visdom_port` keys in the
        config file.

"""

import copy
import json
import logging
import os
import random
import time
import warnings
56
from dataclasses import field
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
57
58
59
60
61
62
63
64
65
66
from typing import Any, Dict, Optional, Tuple

import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
67
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
68
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
69
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
70
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
71
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
72
73
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
74
    Configurable,
75
    enable_get_default_args,
76
    expand_args_fields,
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
77
78
79
80
81
82
    get_default_args_field,
    remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase

logger = logging.getLogger(__name__)

# Fail early on unsupported Hydra installations.
if version.parse(hydra.__version__) < version.Version("1.1"):
    raise ValueError(
        f"Hydra version {hydra.__version__} is too old."
        " (Implicitron requires version 1.1 or later.)"
    )

try:
    # only makes sense in FAIR cluster; imported purely for its side effects
    # (the name is never used in this file).
    import pytorch3d.implicitron.fair_cluster.slurm  # noqa: F401
except ModuleNotFoundError:
    # Deliberate best-effort: the module is absent outside that environment.
    pass


def init_model(
    cfg: DictConfig,
    force_load: bool = False,
    clear_stats: bool = False,
    load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
    """
    Build a `GenericModel` and, optionally, restore it from a checkpoint.

    The checkpoint is searched for in `cfg.exp_dir`. When `cfg.resume_epoch`
    is positive, the checkpoint of that epoch is used; otherwise the last one
    found is used. The checkpoint is only loaded when `force_load` is true or
    `cfg.resume` is set. If no checkpoint exists, the freshly initialised
    model is returned, unless `force_load` is true.

    Args:
        cfg: Experiment config. Reads `architecture`, `generic_model_args`,
            `visdom_server`, `visdom_port`, `exp_dir`, `resume` and
            `resume_epoch`.
        force_load: If true, load the checkpoint even if `cfg.resume` is false.
        clear_stats: If true, keep the freshly created stats object instead of
            the one stored in the checkpoint.
        load_model_only: If true, restore only the model weights; the
            optimizer state and the stats are not loaded.

    Returns:
        model: The model, with weights loaded from the checkpoint if resuming.
        stats: The stats structure (fresh, or restored from the checkpoint).
        optimizer_state: The optimizer state dict from the checkpoint, or None
            if it was not loaded.

    Raises:
        ValueError: If `cfg.architecture` is not "generic".
        FileNotFoundError: If `force_load` is true but no checkpoint is found.
    """
    if cfg.architecture != "generic":
        raise ValueError(f"No such arch {cfg.architecture}.")
    model = GenericModel(**cfg.generic_model_args)

    # Names of the network outputs tracked by the stats object.
    log_vars = (
        copy.deepcopy(list(model.log_vars))
        if hasattr(model, "log_vars")
        else ["objective"]
    )

    visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"
    stats = Stats(
        log_vars,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=cfg.visdom_server,
        visdom_port=cfg.visdom_port,
    )

    # Locate the checkpoint: a specific epoch if requested, else the latest one.
    if cfg.resume_epoch > 0:
        model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
    else:
        model_path = model_io.find_last_checkpoint(cfg.exp_dir)

    if model_path is None:
        if force_load:
            raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")
        return model, stats, None

    logger.info("found previous model %s", model_path)
    if not (force_load or cfg.resume):
        logger.info("   -> but not resuming -> starting from scratch")
        return model, stats, None

    logger.info("   -> resuming")
    optimizer_state = None
    if load_model_only:
        # Weights only: the stats and the optimizer state stay fresh.
        model_state_dict = torch.load(model_io.get_model_path(model_path))
    else:
        model_state_dict, stats_load, optimizer_state = model_io.load_model(
            model_path
        )
        if clear_stats:
            logger.info("   -> clearing stats")
        else:
            if stats_load is not None:
                stats = stats_load
            else:
                logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                last_epoch = model_io.parse_epoch_from_model_path(model_path)
                logger.info(f"Estimated resume epoch = {last_epoch}")
                # Advance the fresh stats object to the checkpoint's epoch.
                for _ in range(last_epoch + 1):
                    stats.new_epoch()
                assert last_epoch == stats.epoch

            # Re-apply this run's settings on top of whatever stats we now hold.
            stats.visdom_env = visdom_env_charts
            stats.visdom_server = cfg.visdom_server
            stats.visdom_port = cfg.visdom_port
            stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
            stats.synchronize_logged_vars(log_vars)

    # TODO (carried over): fix on creation of the buffers; strict loading is
    # expected to fail in many cases, hence the non-strict fallback.
    try:
        model.load_state_dict(model_state_dict, strict=True)
    except RuntimeError as err:
        logger.error(err)
        logger.info("Cant load state dict in strict mode! -> trying non-strict")
        model.load_state_dict(model_state_dict, strict=False)
    model.log_vars = log_vars

    return model, stats, optimizer_state


def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, ...] = (0.9, 0.999),
    milestones: tuple = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights. If it defines
            `_get_param_groups`, that method provides the parameter groups;
            otherwise all parameters with `requires_grad` form one group.
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved; the scheduler is stepped
            this many times.
        breed: The type of optimizer to use: "sgd", "adagrad" or "adam".
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only
            "multistep" is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient
            and its square in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning
            rate is modified
        max_epochs: The maximum number of epochs to run the optimizer for.
            Not used by this function. It is part of the signature so that
            it appears in the solver config, from which the training loop
            reads it.

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raise:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        # pyre-ignore[29]
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info("  -> solver type = %s", breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info("  -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When loading from checkpoint, fast-forward the scheduler so that the
    # lr matches the one reached after `last_epoch` epochs.
    for _ in range(last_epoch):
        scheduler.step()

    optimizer.zero_grad()
    return optimizer, scheduler


# Registers init_optimizer's defaults; ExperimentConfig builds `solver_args`
# with get_default_args_field(init_optimizer).
enable_get_default_args(init_optimizer)


def trainvalidate(
    model,
    stats,
    epoch,
    loader,
    optimizer,
    validation: bool,
    bp_var: str = "objective",
    metric_print_interval: int = 5,
    visualize_interval: int = 100,
    visdom_env_root: str = "trainvalidate",
    clip_grad: float = 0.0,
    device: str = "cuda:0",
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        model: The model module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        epoch: The index of the current epoch
        loader: The dataloader to use for the loop
        optimizer: The optimizer module optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        metric_print_interval: The batch interval at which the stats should be
            logged. Stats are always printed on the last batch; a value <= 0
            prints only there.
        visualize_interval: The batch interval at which the visualizations
            should be plotted. A value <= 0 disables visualization.
        visdom_env_root: The name of the visdom environment to use for plotting
        clip_grad: Optionally clip the gradient norms.
            If set to a value <=0.0, no clipping
        device: The device on which to run the model.
        **kwargs: Ignored. Allows callers to pass a whole config mapping.

    Returns:
        None
    """

    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = visdom_env_root + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, batch in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = batch.to(device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we dont overwrite something
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update (guard avoids modulo-by-zero)
        if last_iter or (
            metric_print_interval > 0 and it % metric_print_interval == 0
        ):
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results
        if visualize_interval > 0 and it % visualize_interval == 0:
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"

            model.visualize(
                viz,
                visdom_env_imgs,
                preds,
                prefix,
            )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            loss.backward()
            if clip_grad > 0.0:
                # Optionally clip the gradient norms (in place). The
                # un-suffixed `clip_grad_norm` is deprecated in PyTorch.
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), clip_grad
                )
                if total_norm > clip_grad:
                    logger.info(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {clip_grad / float(total_norm)}."
                    )

            optimizer.step()


def run_training(cfg: DictConfig, device: str = "cpu") -> None:
    """
    Entry point to run the training and validation loops
    based on the specified config file.

    Args:
        cfg: The experiment config.
        device: The device on which to run the model.

    Validation runs every `cfg.validation_interval` epochs and test
    evaluation every `cfg.test_interval` epochs. A non-positive interval
    disables the corresponding loop.
    """

    # set the debug mode
    if cfg.detect_anomaly:
        logger.info("Anomaly detection!")
    torch.autograd.set_detect_anomaly(cfg.detect_anomaly)

    # create the output folder
    os.makedirs(cfg.exp_dir, exist_ok=True)
    _seed_all_random_engines(cfg.seed)
    remove_unused_components(cfg)

    # dump the exp config to the exp dir
    try:
        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
        OmegaConf.save(config=cfg, f=cfg_filename)
    except PermissionError:
        warnings.warn("Cant dump config due to insufficient permissions!")

    # setup datasets
    datasource = ImplicitronDataSource(**cfg.data_source_args)
    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
    task = datasource.get_task()
    all_train_cameras = datasource.get_all_train_cameras()

    # init the model
    model, stats, optimizer_state = init_model(cfg)
    start_epoch = stats.epoch + 1

    # move model to gpu
    model.to(device)

    # only run evaluation on the test dataloader
    if cfg.eval_only:
        _eval_and_dump(
            cfg,
            task,
            all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
        )
        return

    # init the optimizer
    optimizer, scheduler = init_optimizer(
        model,
        optimizer_state=optimizer_state,
        last_epoch=start_epoch,
        **cfg.solver_args,
    )

    # check the scheduler and stats have been initialized correctly
    assert scheduler.last_epoch == stats.epoch + 1
    assert scheduler.last_epoch == start_epoch

    # loop through epochs
    for epoch in range(start_epoch, cfg.solver_args.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            _seed_all_random_engines(cfg.seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.info(f"scheduler lr = {cur_lr:1.2e}")

            # train loop
            trainvalidate(
                model,
                stats,
                epoch,
                dataloaders.train,
                optimizer,
                False,
                visdom_env_root=vis_utils.get_visdom_env(cfg),
                device=device,
                **cfg,
            )

            # val loop (optional); guard avoids modulo-by-zero
            if (
                dataloaders.val is not None
                and cfg.validation_interval > 0
                and epoch % cfg.validation_interval == 0
            ):
                trainvalidate(
                    model,
                    stats,
                    epoch,
                    dataloaders.val,
                    optimizer,
                    True,
                    visdom_env_root=vis_utils.get_visdom_env(cfg),
                    device=device,
                    **cfg,
                )

            # eval loop (optional)
            if (
                dataloaders.test is not None
                and cfg.test_interval > 0
                and epoch % cfg.test_interval == 0
            ):
                _run_eval(
                    model, all_train_cameras, dataloaders.test, task, device=device
                )

            assert stats.epoch == epoch, "inconsistent stats!"

            # delete previous models if required
            # save model
            if cfg.store_checkpoints:
                if cfg.store_checkpoints_purge > 0:
                    for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                        model_io.purge_epoch(cfg.exp_dir, prev_epoch)
                outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
                model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)

            scheduler.step()

            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if cfg.test_when_finished:
        _eval_and_dump(
            cfg,
            task,
            all_train_cameras,
            datasets,
            dataloaders,
            model,
            stats,
            device=device,
        )


def _eval_and_dump(
    cfg,
    task: Task,
    all_train_cameras: Optional[CamerasBase],
    datasets: DatasetMap,
    dataloaders: DataLoaderMap,
    model,
    stats,
    device,
) -> None:
    """
    Run the evaluation loop with the test data loader and write the
    evaluation results to `results_test.json` in `cfg.exp_dir`.

    Args:
        cfg: The experiment config (only `exp_dir` is read here).
        task: The task passed on to the result summarization.
        all_train_cameras: Training cameras passed to the per-batch evaluation.
        datasets: Not used by this function; kept for interface stability.
        dataloaders: Must provide a non-None `test` loader.
        model: The model to evaluate.
        stats: Its `epoch` is recorded as "eval_epoch" in every result entry.
        device: The device on which to run the evaluation.

    Raises:
        ValueError: If `dataloaders.test` is None.
    """

    dataloader = dataloaders.test

    if dataloader is None:
        raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')

    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)

    # add the evaluation epoch to the results
    for r in results:
        r["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.

    Returns a deep copy of `frame_data` in which the fields `image_rgb`,
    `depth_map`, `fg_probability` and `mask_crop` are multiplied by 0 for
    every frame that `ds_utils.is_known_frame` does not mark as known.
    Fields that are None are left as None. The input is not modified.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    # Per-frame 0/1 weights, expanded with three trailing singleton dims.
    # This assumes the masked fields are 4D with frames along dim 0
    # (TODO confirm against FrameData).
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value = getattr(frame_data_for_eval, k)
        if value is None:
            # Optional field missing for this batch: nothing to mask.
            continue
        # Multiplication allocates a new tensor, so no explicit clone is needed.
        setattr(frame_data_for_eval, k, value * is_known)
    return frame_data_for_eval


def _run_eval(model, all_train_cameras, loader, task: Task, device):
    """
    Run the evaluation loop on the test dataloader.

    Unknown frames are masked out of the model input (see
    `_get_eval_frame_data`). The unmasked batch is passed to
    `evaluate.eval_batch` together with the rendered predictions. Leaves
    `model` in eval mode.

    Args:
        model: The model to evaluate.
        all_train_cameras: Passed to `evaluate.eval_batch` as `source_cameras`.
        loader: The test dataloader.
        task: Passed to `evaluate.summarize_nvs_eval_results`.
        device: The device for the model inputs and the LPIPS network.

    Returns:
        The "results" entry of the category summary returned by
        `evaluate.summarize_nvs_eval_results`.
    """
    lpips_model = lpips.LPIPS(net="vgg")
    lpips_model = lpips_model.to(device)

    model.eval()

    per_batch_eval_results = []
    logger.info("Evaluating model ...")
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(device)

        # mask out the unknown images so that the model does not see them
        frame_data_for_eval = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            preds = model(
                **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
            )
            implicitron_render = copy.deepcopy(preds["implicitron_render"])
            # Score against the original (unmasked) frame data.
            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    implicitron_render,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_train_cameras,
                )
            )

    _, category_result = evaluate.summarize_nvs_eval_results(
        per_batch_eval_results, task
    )

    return category_result["results"]


def _seed_all_random_engines(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch's RNG with `seed`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


class ExperimentConfig(Configurable):
    """
    Top-level configuration of an Implicitron experiment, registered with
    Hydra as "default_config".
    """

    # Sub-configs generated from the default arguments of the model class,
    # of `init_optimizer` and of the data source.
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
    # Only "generic" is accepted by init_model.
    architecture: str = "generic"
    # Passed to torch.autograd.set_detect_anomaly.
    detect_anomaly: bool = False
    # Skip training and only run the test evaluation.
    eval_only: bool = False
    # Directory for checkpoints, the dumped config and evaluation results.
    exp_dir: str = "./data/default_experiment/"
    exp_idx: int = 0
    # GPU to use; also written to CUDA_VISIBLE_DEVICES.
    gpu_idx: int = 0
    # Batch interval for printing stats in the train/val loop.
    metric_print_interval: int = 5
    # Resume from the checkpoint found in exp_dir, if any.
    resume: bool = True
    # Epoch of the checkpoint to resume from; <= 0 means the last one.
    resume_epoch: int = -1
    # Base seed; it is re-applied as seed + epoch at every epoch.
    seed: int = 0
    store_checkpoints: bool = True
    # Checkpoints older than this many epochs are purged; <= 0 keeps all.
    store_checkpoints_purge: int = 1
    # Run test evaluation every test_interval epochs; <= 0 disables it.
    test_interval: int = -1
    test_when_finished: bool = False
    # Run the validation loop every validation_interval epochs.
    validation_interval: int = 1
    # Presumably the visdom environment name read by vis_utils.get_visdom_env
    # (TODO confirm).
    visdom_env: str = ""
    visdom_port: int = 8097
    visdom_server: str = "http://127.0.0.1"
    # Batch interval for visualizations in the train/val loop.
    visualize_interval: int = 1000
    # Max gradient norm; <= 0.0 disables clipping.
    clip_grad: float = 0.0

    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},  # Make hydra not change the working dir.
            "output_subdir": None,  # disable storing the .hydra logs
        }
    )


expand_args_fields(ExperimentConfig)

# Register the structured config under the name used by `@hydra.main`.
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)


@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
    """
    Hydra entry point: pick the device and launch `run_training`.

    `CUDA_VISIBLE_DEVICES` is restricted to `cfg.gpu_idx`. Once that mask is
    in effect, the requested GPU is the only visible device and must be
    addressed as `cuda:0`. Using `cuda:{gpu_idx}` would make any
    `gpu_idx > 0` silently fall back to the CPU.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
    # Set the device
    device = "cpu"
    if torch.cuda.is_available():
        n_visible = torch.cuda.device_count()
        if n_visible == 1:
            # The mask took effect (or the machine has a single GPU):
            # the only visible device has index 0.
            device = "cuda:0"
        elif cfg.gpu_idx < n_visible:
            # CUDA was initialised before the mask was set, so the variable
            # had no effect; index into the full device list instead.
            device = f"cuda:{cfg.gpu_idx}"
    logger.info(f"Running experiment on device: {device}")
    run_training(cfg, device)


if __name__ == "__main__":
    experiment()