state.py 12.1 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import collections
import dataclasses
from abc import ABC, abstractmethod
from typing import List

import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.p2p import P2P

logger = logging.get_logger(__name__)


@dataclasses.dataclass
class SendActivation:
    activation: torch.Tensor
    to_rank: int
    p2p: P2P

    def __call__(self):
        self.p2p.send_tensors([self.activation], to_rank=self.to_rank)


@dataclasses.dataclass
class RecvActivation:
    from_rank: int
    p2p: P2P

    def __call__(self) -> torch.Tensor:
        return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]


@dataclasses.dataclass
class SendGrad:
    grad: torch.Tensor
    to_rank: int
    p2p: P2P

    def __call__(self):
        self.p2p.send_tensors([self.grad], to_rank=self.to_rank)


@dataclasses.dataclass
class RecvGrad:
    from_rank: int
    p2p: P2P

    def __call__(self) -> torch.Tensor:
        return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]


class PipelineBatchState(ABC):
    activations_buffer = collections.deque()

    @abstractmethod
    def register_activation_requiring_backward(self, activation: torch.Tensor):
        ...

    @abstractmethod
    def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
        ...

    @abstractmethod
    def register_recv_activation(self, from_rank: int, p2p: P2P):
        ...

    @abstractmethod
    def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
        ...

    @abstractmethod
    def register_recv_grad(self, from_rank: int, p2p: P2P):
        ...

    @abstractmethod
    def run_communication(self, send_only_activation: bool = False):
        ...

    @abstractmethod
    def new_micro_batch_forward(self):
        ...

    @abstractmethod
    def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
        ...


@dataclasses.dataclass
class PipelineTrainBatchState(PipelineBatchState):
    microbatches_activations_to_send = collections.deque()
    microbatches_activations_to_recv = collections.deque()
    microbatches_grads_to_send = collections.deque()
    microbatches_grads_to_recv = collections.deque()
    grads_buffer = collections.deque()

    # List of list, first index represent micro_batch_id, second index represent activations that needs to be popped
    microbatches_activations_requiring_backward = collections.deque()

    # Reinitialise counter
    nb_backwards = 0
    nb_forwards = 0

    def register_activation_requiring_backward(self, activation: torch.Tensor):
        # Register the activation to last microbatch
        self.microbatches_activations_requiring_backward[-1].append(activation)

    def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
        # TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send activations from higher ranks
        self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))

    def register_recv_activation(self, from_rank: int, p2p: P2P):
        # TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv activations from lower ranks
        self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))

    def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
        # TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send gradients to lower ranks
        self.microbatches_grads_to_send.append(SendGrad(grad=grad, to_rank=to_rank, p2p=p2p))

    def register_recv_grad(self, from_rank: int, p2p: P2P):
        # TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv gradients from higher ranks
        self.microbatches_grads_to_recv.append(RecvGrad(from_rank=from_rank, p2p=p2p))

    def run_communication(self, send_only_activation: bool = False):
        """Run communication in a specific order: send activation, recv activation, send grad, recv grad
        Only one communication is done at a time."""
        log_rank(
            f"activation_to_send: {len(self.microbatches_activations_to_send)} | "
            f"activation_to_recv: {len(self.microbatches_activations_to_recv)} | "
            f"grads_to_send: {len(self.microbatches_grads_to_send)} | "
            f"grads_to_recv: {len(self.microbatches_grads_to_recv)} | "
            f"activation_buffer: {len(self.activations_buffer)} | "
            f"grads_buffer: {len(self.grads_buffer)}",
            logger=logger,
            level=logging.DEBUG,
        )
        # Pop one send activation
        if len(self.microbatches_activations_to_send) > 0:
            send_activation = self.microbatches_activations_to_send.popleft()
            # Execute
            activation_send_requires_grad = send_activation.activation.requires_grad
            send_activation()
            if send_only_activation:
                return

        # Pop one recv activation
        if len(self.microbatches_activations_to_recv) > 0:
            recv_activation = self.microbatches_activations_to_recv.popleft()
            # Execute
            recv_activation_tensor = recv_activation()
            self.activations_buffer.append(recv_activation_tensor)
            # If somehow you receive a tensor without the need of backward, you shouldn't do cross communication
            if recv_activation_tensor.requires_grad is False:
                return

        # Pop one send gradient
        if len(self.microbatches_grads_to_send) > 0:
            send_grad = self.microbatches_grads_to_send.popleft()
            # Execute
            send_grad()

        # Pop one recv gradient
        if len(self.microbatches_grads_to_recv) > 0:
            # Send activation until `activation_send_requires_grad` is True
            while len(self.microbatches_activations_to_send) > 0 and not activation_send_requires_grad:
                send_activation = self.microbatches_activations_to_send.popleft()
                # Execute
                activation_send_requires_grad = send_activation.activation.requires_grad
                send_activation()
            recv_grad = self.microbatches_grads_to_recv.popleft()
            # Execute
            self.grads_buffer.append(recv_grad())

        # TODO @thomasw21: I need some mechanism to point to whatever is now sorted in a buffer, typically some id that would point to the correct tensor in our buffer instead of relying on the sorted list.

    def new_micro_batch_forward(self):
        self.microbatches_activations_requiring_backward.append(collections.deque())

    def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
        return self.microbatches_activations_requiring_backward.popleft()

    def check_buffers_empty(self):
        assert (
            len(self.microbatches_activations_requiring_backward) == 0
        ), f"There are still activations that require backward: {len(self.microbatches_activations_requiring_backward)}"
        assert (
            len(self.microbatches_activations_to_send) == 0
        ), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
        assert (
            len(self.microbatches_activations_to_recv) == 0
        ), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
        assert (
            len(self.microbatches_grads_to_send) == 0
        ), f"There are gradients left for me to send still: {len(self.microbatches_grads_to_send)}"
        assert (
            len(self.microbatches_grads_to_recv) == 0
        ), f"There are gradients left for me to recv still: {len(self.microbatches_grads_to_recv)}"


@dataclasses.dataclass
class PipelineEvalBatchState(PipelineBatchState):
    microbatches_activations_to_send = collections.deque()
    microbatches_activations_to_recv = collections.deque()
    activations_buffer = collections.deque()

    def register_activation_requiring_backward(self, activation: torch.Tensor):
        pass

    def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
        self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))

        # There's a cross communication
        if len(self.microbatches_activations_to_recv) > 0 and len(self.microbatches_activations_to_recv) > 0:
            self.run_communication()

    def register_recv_activation(self, from_rank: int, p2p: P2P):
        self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))

        # There's a cross communication
        if len(self.microbatches_activations_to_recv) > 0 and len(self.microbatches_activations_to_recv) > 0:
            self.run_communication()

    def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
        raise NotImplementedError("You can't register a send grad in pipeline eval mode")

    def register_recv_grad(self, from_rank: int, p2p: P2P):
        raise NotImplementedError("You can't register a recv grad in pipeline eval mode")

    def new_micro_batch_forward(self):
        pass

    def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
        pass

    def run_communication(self, send_only_activation: bool = False):
        # four cases:
        #  - you receive from higher rank and you send to higher rank
        #  - You receive from higher rank and you send to lower rank
        #  - you receive from lower rank and you send to higher rank
        #  - you receive from lower rank and you send to lower rank

        send_activation = None
        # Pop all send activation
        for _ in range(min(1, len(self.microbatches_activations_to_send))):
            send_activation = self.microbatches_activations_to_send.popleft()

        # Pop all recv activation
        recv_activation = None
        for _ in range(min(1, len(self.microbatches_activations_to_recv))):
            recv_activation = self.microbatches_activations_to_recv.popleft()

        if send_activation is None:
            if recv_activation is None:
                raise ValueError("Why the hell do we communicate when there's nothing to communicate?")
            self.activations_buffer.append(recv_activation())
        else:
            if recv_activation is None:
                send_activation()
            else:
                # Define in which order to we do it.
                # Actually we can't do any heuristics as you need global information in order to define clear ordering.
                # We make a BIG assumption that only ONE rank receives from higher rank and sends to higher rank.
                # In this case we find the "lowest" rank, send first
                # All the other ranks receive first and send after
                # Lowest rank receives.
                # If we knew who was involved in the cycle, we could just randomly choose one rank to first send then recv, however it's not clear who's involved
                p2p = send_activation.p2p
                assert p2p == recv_activation.p2p
                is_lowest = send_activation.to_rank > dist.get_rank(
                    p2p.pg
                ) and recv_activation.from_rank > dist.get_rank(p2p.pg)
                if is_lowest:
                    send_activation()
                    self.activations_buffer.append(recv_activation())
                else:
                    self.activations_buffer.append(recv_activation())
                    send_activation()

    def check_buffers_empty(self):
        assert (
            len(self.microbatches_activations_to_send) == 0
        ), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
        assert (
            len(self.microbatches_activations_to_recv) == 0
        ), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
        assert (
            len(self.activations_buffer) == 0
        ), f"There are activations left in the buffer: {len(self.activations_buffer)}"