# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""The AMPnetPipe interface."""

from typing import Any

from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from fairscale.nn.pipe import AsyncPipe

from .ampnet import AsyncAMPnetEventLoop

__all__ = ["AMPnetPipe"]


class AMPnetPipe(AsyncPipe):
    """Asynchronous version of the MultiProcessPipe implementation.

    AMPnetPipe avoids the bubble issue by using stale weights and gradients.
    The implementation closely follows the paper:
    https://arxiv.org/abs/1705.09786
    """

    def __init__(self, **kwargs: Any) -> None:
        """Forward every keyword argument unchanged to ``AsyncPipe.__init__``.

        Only keyword arguments are accepted; positional arguments are rejected
        by this signature.
        """
        super().__init__(**kwargs)

    def interleave(
        self,
        lm_dataloader: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        transform_logger_object: Any,
        min_update_interval: int = 1,
        weight_prediction: bool = False,
    ) -> None:
        """Run the AMPnet event loop that corresponds to this process's stage.

        An ``AsyncAMPnetEventLoop`` is built from this pipe's partitions,
        process group, transport, checkpoint-stop setting and input device.
        One of its three ``*_across_minibatches`` entry points is then called:
        the "head" variant when the group rank is 0, the "tail" variant when
        ``self.final_stage`` is truthy, and the plain variant otherwise.

        Args:
            lm_dataloader: Data loader handed through to the event loop.
            criterion: Loss module handed through to the event loop.
            optimizer: Optimizer handed through to the event loop.
            transform_logger_object: Opaque object handed through to the
                event loop; its contract is defined by the event loop.
            min_update_interval: Passed to the ``AsyncAMPnetEventLoop``
                constructor.
            weight_prediction: Passed to the ``AsyncAMPnetEventLoop``
                constructor.

        Raises:
            AssertionError: If ``self.group`` is falsy.
        """
        partitions = self.partitions

        # AMPnet implementation doesn't handle skip_trackers!

        # A process group is mandatory: its rank selects the event loop below.
        assert self.group
        rank = self.group.rank()

        transport = self.pipeline.transport  # type: ignore
        checkpoint_stop = self.pipeline.checkpoint_stop  # type: ignore
        ampnet_event_loop = AsyncAMPnetEventLoop(
            partitions,
            self.group,
            transport,
            min_update_interval,
            weight_prediction,
            checkpoint_stop,
            self.input_device,
        )

        # Rank 0 is checked first, so a process that is both rank 0 and the
        # final stage runs the head loop.
        if rank == 0:
            ampnet_event_loop.event_loop_head_across_minibatches(
                lm_dataloader, criterion, optimizer, transform_logger_object
            )
        elif self.final_stage:
            ampnet_event_loop.event_loop_tail_across_minibatches(
                lm_dataloader, criterion, optimizer, transform_logger_object
            )
        else:
            ampnet_event_loop.event_loop_across_minibatches(
                lm_dataloader, criterion, optimizer, transform_logger_object
            )