from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union

import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel

logger = logging.get_logger(__name__)


class PipelineEngine(ABC):
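    """Base class for pipeline-parallel execution schedules.

    Subclasses implement `train_batch_iter`, which decides in which order the
    forward and backward passes of the different micro-batches are run.
    """
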
    def __init__(self):
        self.nb_microbatches: Optional[int] = None

    def forward(
        self,
        context: ContextManagers,
        state: PipelineTrainBatchState,
        micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
        model: torch_nn.Module,
    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
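        """Run the forward pass of a single micro-batch.

        Wraps the model output in a dict under the key "loss", divides the loss by
        `self.nb_microbatches`, and registers it in `state` so that a later call to
        `backward` can backpropagate through it. When the loss lives on another
        rank, `output["loss"]` is a `TensorPointer` and is left untouched.
        """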
        # Increment the number of forwards
        state.nb_forwards += 1
        log_rank(
            f"Forward micro batch id: {state.nb_forwards}",
            logger=logger,
            level=logging.DEBUG,
        )

        # IMPORTANT: the pipeline state acts as the container storing all the intermediary activations of this micro-batch
        state.new_micro_batch_forward()
        with context:
            output = model(**micro_batch)

        # We make `output` a dict
        if not isinstance(output, dict):
            output = {"loss": output}

        # We normalize our loss
        if not isinstance(output["loss"], TensorPointer):
            output["loss"] = output["loss"] / self.nb_microbatches

        # Add output as activations that require backward pass
        if not isinstance(output["loss"], TensorPointer):
            assert output["loss"].requires_grad
            state.register_activation_requiring_backward(output["loss"])
        return output

    @staticmethod
    def _get_fwd_context(model: torch_nn.Module):
        is_ddp = isinstance(model, DistributedDataParallel)
        # We never want to trigger a DDP sync in the next backward pass
        context = ContextManagers([model.no_sync()] if is_ddp else [])
        return context

    def backward(
        self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator]
    ):
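        """Run the backward pass for the most recently registered activations.

        Pops the last set of activations requiring backward from `state` and
        backpropagates through their sum, either directly or through
        `grad_accumulator` when gradient accumulation is used. No-op on ranks that
        registered no activation (e.g. ranks that only hold a `TensorPointer` to
        the loss).
        """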
        # Increment the number of backwards
        state.nb_backwards += 1
        log_rank(
            f"Backward micro batch id: {state.nb_backwards}",
            logger=logger,
            level=logging.DEBUG,
        )
        # Go backward entirely
        activations = state.pop_last_activations_requiring_backward()
        if len(activations) == 0:
            return

        with context:
            if grad_accumulator is None:
                sum(activations).backward()
            else:
                grad_accumulator.backward(sum(activations))

        # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
        # with context:
        #     if grad_accumulator is None:
        #         for activation in reversed(activations): #TODO @nouamane: need to bwd only 2nd chunk
        #             activation.backward()
        #     else:
        #         for activation in reversed(activations):
        #             grad_accumulator.backward(activation)

    def _get_bwd_context(
        self,
        model: torch_nn.Module,
        nb_backwards: int,
        grad_accumulator: Optional[GradientAccumulator],
    ):
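        """Build the context managers wrapping a backward pass.

        When the model is wrapped in `DistributedDataParallel`: for every backward
        pass but the last one, gradient-accumulator synchronization is disabled;
        on the last one (`nb_backwards == nb_microbatches - 1`) DDP is asked to
        all-reduce gradients during that backward pass.
        """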
        assert (
            self.nb_microbatches is not None
        ), "You must call `train_batch_iter` first and set `self.nb_microbatches`"
        is_ddp = isinstance(model, DistributedDataParallel)
        context_list = []
        if is_ddp:
            if grad_accumulator is not None and nb_backwards < self.nb_microbatches - 1:
                context_list.append(grad_accumulator.no_sync())  # Prevents grad accumulator from syncing
            if nb_backwards == self.nb_microbatches - 1:
                # Triggers DDP to sync gradients in the next backward pass
                context_list.append(ddp_trigger_sync_in_bwd(model_ddp=model))
        context = ContextManagers(context_list)
        return context

    @abstractmethod
    def train_batch_iter(
        self,
        model: torch_nn.Module,
        pg: ProcessGroup,
        batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
        nb_microbatches: int,
        grad_accumulator: Optional[GradientAccumulator],
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        """If the model returns a tensor, we use it as the loss to backpropagate. If the model returns a dict, we assume the key "loss" holds the loss to backpropagate."""
        ...

    @torch.inference_mode()
    def validate_batch_iter(
        self,
        model: torch_nn.Module,
        batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
        nb_microbatches: int,
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
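        """Run validation: forward every micro-batch under `torch.inference_mode`.

        Returns one output dict per micro-batch. When the loss lives on this rank
        it is detached; otherwise the `TensorPointer` is kept as-is.
        """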
        # Assign a new state for the current batch
        state = PipelineTrainBatchState()  # TODO: do i need state?
        self.nb_microbatches = nb_microbatches

        outputs = []

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # All forward
            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

        return outputs


class AllForwardAllBackwardPipelineEngine(PipelineEngine):
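    """All-forward-all-backward (AFAB) schedule.

    Runs the forward pass of every micro-batch first, then runs all the backward
    passes. Simple, but the activations of every micro-batch are alive at the
    same time, so activation memory grows with `nb_microbatches`.
    """
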
    def __init__(self):
        super().__init__()

    def train_batch_iter(
        self,
        model: torch_nn.Module,
        pg: ProcessGroup,
        batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
        nb_microbatches: int,
        grad_accumulator: Optional[GradientAccumulator],
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        # Assign a new state for the current batch
        state = PipelineTrainBatchState()
        self.nb_microbatches = nb_microbatches

        outputs = []

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # All forward
            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

            # All backward
            for _ in range(len(state.microbatches_activations_requiring_backward)):
                context = self._get_bwd_context(
                    model=model,
                    nb_backwards=state.nb_backwards,
                    grad_accumulator=grad_accumulator,
                )
                self.backward(context=context, state=state, grad_accumulator=grad_accumulator)

                for _ in range(len(state.microbatches_grads_to_send)):
                    send_grads = state.microbatches_grads_to_send.popleft()
                    # Execute
                    send_grads()
            # Make sure that micro batches are all fully consumed
            state.check_buffers_empty()

            return outputs


class OneForwardOneBackwardPipelineEngine(PipelineEngine):
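    """One-forward-one-backward (1F1B) schedule (see https://arxiv.org/abs/2104.04473).

    Warm-up: each rank runs `pp_size - rank - 1` forward passes. Steady state: it
    then alternates one forward pass with one backward pass. Cool-down: the
    remaining backward passes are drained. Compared to AFAB this bounds the number
    of in-flight activations per rank, at the cost of requiring at least
    `pp_size - 1` micro-batches.
    """
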
    def __init__(self):
        super().__init__()

    def train_batch_iter(
        self,
        model: torch_nn.Module,
        pg: ProcessGroup,
        batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
        nb_microbatches: int,
        grad_accumulator: Optional[GradientAccumulator],
    ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
        """Check https://arxiv.org/abs/2104.04473 for diagrams for the pipeline engine"""
        self.nb_microbatches = nb_microbatches
        assert (
            self.nb_microbatches >= pg.size() - 1
        ), f"Number of microbatches ({self.nb_microbatches}) must be at least PP_SIZE-1={pg.size() - 1} when using the OneForwardOneBackwardPipelineEngine"

        state = PipelineTrainBatchState()

        outputs = []
        batch = iter(batch)

        current_pp_rank = dist.get_rank(pg)

        with attach_pipeline_state_to_model(model=model, pipeline_state=state):
            # Warm-up: run `pg.size() - current_pp_rank - 1` forward passes before the steady state
            for _ in range(pg.size() - current_pp_rank - 1):
                micro_batch = next(batch)
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)

                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_activations_to_send)):
                    send_activation = state.microbatches_activations_to_send.popleft()
                    # Execute
                    send_activation()

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

            for micro_batch in batch:
                context = self._get_fwd_context(model=model)
                output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)

                # We make `output` a dict
                if not isinstance(output, dict):
                    output = {"loss": output}

                # Store the loss for each microbatch
                if not isinstance(output["loss"], TensorPointer):
                    output = {k: v.detach() for k, v in output.items()}
                outputs.append(output)

                # One backward
                context = self._get_bwd_context(
                    model=model,
                    nb_backwards=state.nb_backwards,
                    grad_accumulator=grad_accumulator,
                )
                self.backward(context=context, state=state, grad_accumulator=grad_accumulator)

            # Check the figure in the paper: the remaining blocks are all backward passes, and only `pg.size() - current_pp_rank - 1` of them are left
            assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1
            # No more activations to send/recv
            assert (
                len(state.microbatches_activations_to_send) == 0
            ), f"There are still activations left to send: {len(state.microbatches_activations_to_send)}"
            assert (
                len(state.microbatches_activations_to_recv) == 0
            ), f"There are still activations left to receive: {len(state.microbatches_activations_to_recv)}"

            # Cool-down: run the backward passes for the remaining micro-batches
            # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
            for _ in range(len(state.microbatches_grads_to_send)):
                send_grads = state.microbatches_grads_to_send.popleft()
                # Execute
                send_grads()
            for _ in range(len(state.microbatches_activations_requiring_backward)):
                context = self._get_bwd_context(
                    model=model,
                    nb_backwards=state.nb_backwards,
                    grad_accumulator=grad_accumulator,
                )
                self.backward(context=context, state=state, grad_accumulator=grad_accumulator)

                # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
                for _ in range(len(state.microbatches_grads_to_send)):
                    send_grads = state.microbatches_grads_to_send.popleft()
                    # Execute
                    send_grads()

            # Make sure that micro batches are all fully consumed
            state.check_buffers_empty()

        return outputs
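

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module): names such as `model`,
# `pp_pg`, `dataloader`, `n_micro_batches` and `grad_accumulator` are
# placeholders a trainer is expected to provide. Each engine consumes an
# iterable of micro-batches and returns one output dict per micro-batch, with
# the loss already divided by `nb_microbatches`:
#
#     engine = OneForwardOneBackwardPipelineEngine()
#     outputs = engine.train_batch_iter(
#         model=model,                        # pipeline-parallel torch module
#         pg=pp_pg,                           # pipeline-parallel ProcessGroup
#         batch=(next(dataloader) for _ in range(n_micro_batches)),
#         nb_microbatches=n_micro_batches,    # must be >= pp_pg.size() - 1 for 1F1B
#         grad_accumulator=grad_accumulator,  # or None when no accumulator is used
#     )
#     losses = [o["loss"] for o in outputs if not isinstance(o["loss"], TensorPointer)]
# ---------------------------------------------------------------------------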