# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytorch_lightning as pl
from .accelerator import BypassAccelerator

class Trainer(pl.Trainer):
    """
    Trainer for cross-graph optimization.

    Parameters
    ----------
    use_cgo : bool
        Whether cross-graph optimization (CGO) is used.
        If it is True, CGO will manage device placement.
        Any device placement from pytorch lightning will be bypassed.
        default: False
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
        When ``use_cgo`` is True, the same keyword arguments are also
        forwarded to the ``BypassAccelerator`` that is installed as the
        trainer's ``accelerator``.

    Raises
    ------
    ValueError
        If ``use_cgo`` is True and ``accelerator`` is also present in
        ``trainer_kwargs``.
    """

    def __init__(self, use_cgo=False, **trainer_kwargs):
        if use_cgo:
            # CGO installs its own accelerator just below; reject a
            # user-supplied one instead of silently overwriting it.
            if "accelerator" in trainer_kwargs:
                raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
            # Device placement is delegated to CGO: the bypass accelerator is
            # constructed with device='cpu' and receives the same keyword
            # arguments as the trainer. Mutating trainer_kwargs is safe because
            # ``**`` always builds a fresh dict local to this call.
            trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu', **trainer_kwargs)

        super().__init__(**trainer_kwargs)