Commit dfcb88ff authored by chenzk

v1.0.8
import torch
from nanotron import logging
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState
logger = logging.get_logger(__name__)
class SendTensorToPipelineBuffer(torch.autograd.Function):
"""Make sending tensors differentiable. The difference is here we don't use `torch.distributed` primites, but store events that's we will pop whenever we need"""
@staticmethod
def forward(
ctx,
activation: torch.Tensor,
to_rank: int,
p2p: P2P,
pipeline_state: PipelineBatchState,
):
assert activation.requires_grad
ctx.p2p = p2p
ctx.to_rank = to_rank
ctx.pipeline_state = pipeline_state
# Send tensors
pipeline_state.register_send_activation(activation, to_rank=to_rank, p2p=p2p)
# HACK @thomasw21: This forces the trigger to backward
return torch.tensor(1, dtype=torch.float, device="cpu", requires_grad=True)
@staticmethod
def backward(ctx, grad_tensor):
p2p = ctx.p2p
to_rank = ctx.to_rank
pipeline_state = ctx.pipeline_state
# Register a gradient receive; the result will be stored in the grads buffer
pipeline_state.register_recv_grad(from_rank=to_rank, p2p=p2p)
if len(pipeline_state.grads_buffer) == 0:
pipeline_state.run_communication()
grad_tensor = pipeline_state.grads_buffer.popleft()
return grad_tensor, None, None, None
class SendTensorWithoutGradientToPipelineBuffer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
dummy_input: torch.Tensor,
activation: torch.Tensor,
to_rank: int,
p2p: P2P,
pipeline_state: PipelineBatchState,
):
assert dummy_input.requires_grad
assert activation.requires_grad is False
ctx.p2p = p2p
ctx.to_rank = to_rank
ctx.pipeline_state = pipeline_state
# Send tensors
pipeline_state.register_send_activation(activation, to_rank=to_rank, p2p=p2p)
# HACK @thomasw21: This forces the trigger to backward
return torch.tensor(1, dtype=torch.float, device="cpu", requires_grad=True)
@staticmethod
def backward(ctx, grad_tensor):
pipeline_state = ctx.pipeline_state
# send only the activations
pipeline_state.run_communication(send_only_activation=True)
return None, None, None, None, None
def send_to_pipeline_state_buffer(tensor: torch.Tensor, to_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
# This is used in order to know where to backward from.
if tensor.requires_grad:
result = SendTensorToPipelineBuffer.apply(tensor, to_rank, p2p, pipeline_state)
else:
# Trick that backward mechanism to just send the tensor.
dummy_input = torch.empty(1, dtype=torch.float, requires_grad=True, device="cpu")
result = SendTensorWithoutGradientToPipelineBuffer.apply(dummy_input, tensor, to_rank, p2p, pipeline_state)
pipeline_state.register_activation_requiring_backward(result)
class RecvTensorFromPipelineBuffer(torch.autograd.Function):
"""Make receiving tensors differentiable"""
@staticmethod
def forward(ctx, activation: torch.Tensor, from_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
ctx.pipeline_state = pipeline_state
ctx.p2p = p2p
ctx.from_rank = from_rank
return activation
@staticmethod
def backward(ctx, grad_tensor):
pipeline_state = ctx.pipeline_state
from_rank = ctx.from_rank
p2p = ctx.p2p
# Send tensors
pipeline_state.register_send_grad(grad_tensor, to_rank=from_rank, p2p=p2p)
return None, None, None, None
def recv_from_pipeline_state_buffer(from_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
pipeline_state.register_recv_activation(from_rank=from_rank, p2p=p2p)
if len(pipeline_state.activations_buffer) == 0:
pipeline_state.run_communication()
activation = pipeline_state.activations_buffer.popleft()
return RecvTensorFromPipelineBuffer.apply(activation, from_rank, p2p, pipeline_state)
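# --- Illustrative sketch (not part of the original sources) ------------------
# A minimal, hedged example of how the two helpers above are typically paired
# across two adjacent pipeline ranks. It assumes a 2-process launch
# (e.g. `torchrun --nproc_per_node=2 ...`), one CUDA device per rank, and that
# `PipelineTrainBatchState` (defined in state.py, shown further below) is
# importable from `nanotron.parallel.pipeline_parallel.state`. Ranks, shapes
# and the module path are assumptions made only for illustration.


def _example_two_stage_send_recv():
    import torch.distributed as torch_dist

    from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState

    torch_dist.init_process_group(backend="nccl")
    rank = torch_dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    p2p = P2P(pg=torch_dist.distributed_c10d._get_default_group(), device=device)
    state = PipelineTrainBatchState()

    if rank == 0:
        state.new_micro_batch_forward()
        activation = torch.randn(4, 8, device=device, requires_grad=True)
        # Only registers the send; the actual communication is executed lazily.
        send_to_pipeline_state_buffer(activation, to_rank=1, p2p=p2p, pipeline_state=state)
        state.run_communication()  # executes the queued activation send
        # Backward on the stored dummy scalar pulls the gradient back from rank 1.
        dummies = state.pop_last_activations_requiring_backward()
        torch.autograd.backward(tuple(dummies))
        assert activation.grad is not None and activation.grad.shape == activation.shape
    else:
        received = recv_from_pipeline_state_buffer(from_rank=0, p2p=p2p, pipeline_state=state)
        received.sum().backward()  # registers a gradient send back to rank 0
        state.run_communication()  # executes the queued gradient send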
import dataclasses
from typing import List, Sequence, Tuple
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
logger = logging.get_logger(__name__)
FIRST_METADATA_SIZE = 7
SECOND_METADATA_SIZE = 1024
ID_TO_DTYPE = [
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
ID_TO_REQUIRES_GRAD = [True, False]
REQUIRES_GRAD_TO_ID = {value: id_ for id_, value in enumerate(ID_TO_REQUIRES_GRAD)}
ID_TO_IS_CONTIGUOUS = [True, False]
IS_CONTIGUOUS_TO_ID = {value: id_ for id_, value in enumerate(ID_TO_IS_CONTIGUOUS)}
@dataclasses.dataclass
class P2PTensorMetaData:
shape: Sequence[int]
stride: Sequence[int]
is_contiguous: bool
untyped_storage_size: int
storage_offset: int
dtype: torch.dtype
requires_grad: bool
def create_empty_storage(self, device: torch.device) -> torch.Tensor:
buffer = torch.empty(
size=(self.untyped_storage_size,),
requires_grad=False,
dtype=torch.int8,
device=device,
memory_format=torch.contiguous_format,
).view(dtype=self.dtype)
buffer.requires_grad = self.requires_grad
if self.is_contiguous:
buffer = buffer.as_strided(
size=tuple(self.shape), stride=tuple(self.stride), storage_offset=self.storage_offset
)
# Complex needs to be viewed as real first
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_real(buffer) if self.dtype.is_complex else buffer
return buffer
def reshape(self, buffer):
"""Changes the way we view buffer in order to fit metadata"""
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_complex(buffer) if self.dtype.is_complex else buffer
# Set shape and stride
if not self.is_contiguous:
buffer = buffer.as_strided(
size=tuple(self.shape), stride=tuple(self.stride), storage_offset=self.storage_offset
)
return buffer
@staticmethod
def to_first_metadata(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
# TODO @nouamane: avoid having two metadata comms, and preallocate shape/stride instead
return torch.tensor(
[
len(tensor.shape),
len(tensor.stride()),
IS_CONTIGUOUS_TO_ID[tensor.is_contiguous()],
get_untyped_storage(tensor).size(),
tensor.storage_offset(),
DTYPE_TO_ID[tensor.dtype],
REQUIRES_GRAD_TO_ID[tensor.requires_grad],
],
dtype=torch.long,
device=device,
)
@staticmethod
def to_second_metadata(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
return torch.tensor(tensor.shape + tensor.stride(), dtype=torch.long, device=device)
@classmethod
def from_metadata(cls, first_metadata: List[int], second_metadata: List[int]):
shape_and_stride = second_metadata
(
num_shape,
num_stride,
is_contiguous,
untyped_storage_size,
storage_offset,
dtype_id,
requires_grad_id,
) = first_metadata
return cls(
shape=shape_and_stride[: len(shape_and_stride) // 2],
stride=shape_and_stride[len(shape_and_stride) // 2 :],
is_contiguous=ID_TO_IS_CONTIGUOUS[is_contiguous],
untyped_storage_size=untyped_storage_size,
storage_offset=storage_offset,
dtype=ID_TO_DTYPE[dtype_id],
requires_grad=ID_TO_REQUIRES_GRAD[requires_grad_id],
)
def view_as_contiguous(tensor: torch.Tensor):
"""Given a tensor, we want to view the tensor as a contiguous storage"""
tensor_numel = tensor.numel()
tensor_element_size = tensor.element_size()
untyped_storage = get_untyped_storage(tensor)
untyped_storage_size = untyped_storage.size()
untyped_element_size = untyped_storage.element_size()
assert (
tensor_numel * tensor_element_size >= untyped_storage_size * untyped_element_size
), "Expect storage_size to be smaller than tensor size. It might not be true, when you use slicing for example though. We probably don't want to support it in our P2P system"
buffer = tensor_from_untyped_storage(untyped_storage=untyped_storage, dtype=tensor.dtype)
return buffer
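# --- Illustrative sketch (not part of the original sources) ------------------
# A local, single-process illustration of what the metadata above encodes: for
# a non-contiguous tensor we ship the flat storage plus
# (shape, stride, storage_offset), and the receiving side rebuilds the same
# view with `as_strided`, which is what `P2PTensorMetaData.reshape` does. The
# shapes below are assumptions made only for illustration.


def _example_rebuild_from_metadata():
    base = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    sliced = base[:, 1:3]  # non-contiguous view: shape (3, 2), stride (4, 1), offset 1
    assert not sliced.is_contiguous()

    # "Send side": flatten the storage and record the view metadata.
    flat_storage = base.reshape(-1).clone()  # stands in for the raw storage payload
    shape, stride, offset = sliced.shape, sliced.stride(), sliced.storage_offset()

    # "Recv side": allocate a flat buffer, copy the payload, re-apply the view.
    recv_buffer = torch.empty_like(flat_storage)
    recv_buffer.copy_(flat_storage)
    rebuilt = recv_buffer.as_strided(size=tuple(shape), stride=tuple(stride), storage_offset=offset)
    assert torch.equal(rebuilt, sliced)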
class P2P:
def __init__(self, pg: dist.ProcessGroup, device: torch.device):
self.pg = pg
self.device = device
self.first_metadata = torch.empty(FIRST_METADATA_SIZE, dtype=torch.long, device=self.device)
self.second_metadata = torch.empty(SECOND_METADATA_SIZE, dtype=torch.long, device=self.device)
def _send_first_metadata_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
first_metadata = P2PTensorMetaData.to_first_metadata(tensor=tensor, device=self.device)
return dist.P2POp(
op=dist.isend,
tensor=first_metadata,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_first_metadata_p2p_op(self, from_rank: int, tag: int = 0) -> Tuple[torch.Tensor, dist.P2POp]:
first_metadata_buffer = torch.empty((FIRST_METADATA_SIZE,), dtype=torch.long, device=self.device)
return first_metadata_buffer, dist.P2POp(
op=dist.irecv,
tensor=first_metadata_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_second_metadata_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
second_metadata = P2PTensorMetaData.to_second_metadata(tensor=tensor, device=self.device)
return dist.P2POp(
op=dist.isend,
tensor=second_metadata,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_second_metadata_p2p_op(
self, shape_length: int, stride_length: int, from_rank: int, tag: int = 0
) -> Tuple[torch.Tensor, dist.P2POp]:
second_metadata_buffer = torch.empty((shape_length + stride_length,), dtype=torch.long, device=self.device)
return second_metadata_buffer, dist.P2POp(
op=dist.irecv,
tensor=second_metadata_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_data_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
return dist.P2POp(
op=dist.isend,
tensor=tensor,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_data_p2p_op(
self, tensor_metadata: P2PTensorMetaData, from_rank: int, tag: int = 0
) -> Tuple[torch.Tensor, dist.P2POp]:
tensor_buffer = tensor_metadata.create_empty_storage(self.device)
return tensor_buffer, dist.P2POp(
op=dist.irecv,
tensor=tensor_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_meta(self, tensor: torch.Tensor, to_rank: int, tag: int):
cpu_tensor = torch.tensor(
[
len(tensor.shape),
len(tensor.stride()),
IS_CONTIGUOUS_TO_ID[tensor.is_contiguous()],
get_untyped_storage(tensor).size(),
tensor.storage_offset(),
DTYPE_TO_ID[tensor.dtype],
REQUIRES_GRAD_TO_ID[tensor.requires_grad],
],
dtype=torch.long,
)
self.first_metadata.copy_(cpu_tensor)
dist.send(
self.first_metadata,
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
second_metadata = tensor.shape + tensor.stride()
assert len(tensor.shape) == self.first_metadata[0]
assert len(tensor.stride()) == self.first_metadata[1]
# increase buffer size
if len(second_metadata) > len(self.second_metadata):
self.second_metadata = torch.empty(len(second_metadata), dtype=torch.long, device=self.device)
self.second_metadata[: len(second_metadata)].copy_(torch.tensor(second_metadata, dtype=torch.long))
dist.send(
self.second_metadata[: len(second_metadata)],
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_meta(self, from_rank: int, tag: int) -> P2PTensorMetaData:
dist.recv(
self.first_metadata,
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
(
num_shape,
num_stride,
is_contiguous,
untyped_storage_size,
storage_offset,
dtype_id,
requires_grad_id,
) = self.first_metadata
# self.pg.recv([second], from_rank, 0).wait() # more direct API
second_metadata_num_elements = num_shape + num_stride
# increase buffer size
if second_metadata_num_elements > len(self.second_metadata):
self.second_metadata = torch.empty(second_metadata_num_elements, dtype=torch.long, device=self.device)
dist.recv(
self.second_metadata[:second_metadata_num_elements],
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
shape = self.second_metadata[:num_shape]
stride = self.second_metadata[num_shape:second_metadata_num_elements]
return P2PTensorMetaData(
dtype=ID_TO_DTYPE[dtype_id],
requires_grad=ID_TO_REQUIRES_GRAD[requires_grad_id],
shape=shape,
stride=stride,
is_contiguous=ID_TO_IS_CONTIGUOUS[is_contiguous],
untyped_storage_size=untyped_storage_size,
storage_offset=storage_offset,
)
def isend_tensors(self, tensors: List[torch.Tensor], to_rank: int, tag: int = 0) -> List[dist.Work]:
futures = []
current_rank = dist.get_rank(self.pg)
logger.debug(f"Current rank {current_rank} sending to rank {to_rank}. Nb_tensors: {len(tensors)}")
for tensor in tensors:
if to_rank != current_rank:
self._send_meta(tensor, to_rank=to_rank, tag=tag)
if tensor.is_contiguous():
buffer = tensor
else:
# If the tensor is not contiguous we send the entire storage
buffer = view_as_contiguous(tensor)
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_real(buffer) if buffer.is_complex() else buffer
futures.append(
dist.isend(
buffer,
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
)
else:
raise ValueError("Tried sending tensor to itself")
return futures
def irecv_tensors(
self, num_tensors: int, from_rank: int, tag: int = 0
) -> Tuple[List[torch.Tensor], List[dist.Work]]:
futures = []
buffers = []
current_rank = dist.get_rank(self.pg)
logger.debug(f"Current rank {current_rank} receiving from rank {from_rank}. Nb_tensors: {num_tensors}")
for _ in range(num_tensors):
if from_rank != current_rank:
meta = self._recv_meta(from_rank=from_rank, tag=tag)
buffer = meta.create_empty_storage(device=self.device)
futures.append(
dist.irecv(
buffer,
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
)
buffer = meta.reshape(buffer=buffer)
# Add to the list
buffers.append(buffer)
else:
raise ValueError("Tried receiving tensor from itself")
return buffers, futures
def send_tensors(self, tensors: List[torch.Tensor], to_rank: int, tag: int = 0):
futures = self.isend_tensors(tensors=tensors, to_rank=to_rank, tag=tag)
for future in futures:
future.wait()
def recv_tensors(self, num_tensors: int, from_rank: int, tag: int = 0) -> List[torch.Tensor]:
buffers, futures = self.irecv_tensors(num_tensors=num_tensors, from_rank=from_rank, tag=tag)
for future in futures:
future.wait()
return buffers
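# --- Illustrative sketch (not part of the original sources) ------------------
# A hedged example of the blocking helpers above under a 2-process launch
# (e.g. `torchrun --nproc_per_node=2 ...`); the backend, shapes and dtypes are
# assumptions made only for illustration.


def _example_p2p_send_recv():
    import torch.distributed as torch_dist

    torch_dist.init_process_group(backend="nccl")
    rank = torch_dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    p2p = P2P(pg=torch_dist.distributed_c10d._get_default_group(), device=device)

    if rank == 0:
        payload = [torch.randn(2, 3, device=device), torch.arange(5, device=device)]
        p2p.send_tensors(payload, to_rank=1)
    else:
        tensors = p2p.recv_tensors(num_tensors=2, from_rank=0)
        # Shape, dtype and requires_grad are reconstructed from the two
        # metadata messages that precede each tensor payload.
        assert tensors[0].shape == (2, 3) and tensors[1].dtype == torch.int64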
class BatchTensorSendRecvState:
"""
This class is used to register send/recv batches of tensors, and
then executes send/recv in `flush()` calls. This is useful for
amortizing the cost of sending and receiving tensors over multiple
iterations.
"""
p2p: P2P
first_metadata_p2p_ops: List[dist.P2POp]
second_metadata_p2p_ops: List[dist.P2POp]
data_p2p_ops: List[dist.P2POp]
recv_first_metadata_buffers: List[torch.Tensor]
recv_from_ranks: List[int]
def __init__(self, p2p: P2P):
self.p2p = p2p
self._reset()
def _reset(self):
self.first_metadata_p2p_ops: List[dist.P2POp] = []
self.second_metadata_p2p_ops: List[dist.P2POp] = []
self.data_p2p_ops: List[dist.P2POp] = []
self.recv_first_metadata_buffers: List[torch.Tensor] = []
self.recv_from_ranks: List[int] = []
def __str__(self):
return f"BatchTensorSendRecvState(first_metadata_p2p_ops={len(self.first_metadata_p2p_ops)}, second_metadata_p2p_ops={len(self.second_metadata_p2p_ops)}, data_p2p_ops={len(self.data_p2p_ops)}, recv_first_metadata_buffers={len(self.recv_first_metadata_buffers)}, recv_from_ranks={self.recv_from_ranks})"
def add_send(self, tensor: torch.Tensor, to_rank: int, tag: int = 0):
self.first_metadata_p2p_ops.append(
self.p2p._send_first_metadata_p2p_op(tensor=tensor, to_rank=to_rank, tag=tag)
)
self.second_metadata_p2p_ops.append(
self.p2p._send_second_metadata_p2p_op(tensor=tensor, to_rank=to_rank, tag=tag)
)
self.data_p2p_ops.append(
self.p2p._send_data_p2p_op(tensor=view_as_contiguous(tensor), to_rank=to_rank, tag=tag)
)
def add_recv(self, from_rank: int, tag: int = 0) -> int:
"""
Only add p2p ops for the first metadata, as `_recv_second_metadata_p2p_op` and `_recv_data_p2p_op`
require the results of the first metadata to be transferred first.
Return: index of the recv_buffer in `self.recv_first_metadata_buffers`
"""
buffer, recv_op = self.p2p._recv_first_metadata_p2p_op(from_rank=from_rank, tag=tag)
self.first_metadata_p2p_ops.append(recv_op)
self.recv_first_metadata_buffers.append(buffer)
self.recv_from_ranks.append(from_rank)
return len(self.recv_first_metadata_buffers) - 1
def _send_recv_first_metadata(self) -> List[List[int]]:
# Send/Recv first metadata
reqs = dist.batch_isend_irecv(self.first_metadata_p2p_ops)
for req in reqs:
req.wait()
# We want an early cpu/gpu sync here as we are right after the wait so it's nearly free.
# Removing the tolist call here delays the sync and will impact performance.
# We need to instantiate it in a list because it is used twice
first_metadatas = [tensor.tolist() for tensor in self.recv_first_metadata_buffers]
return first_metadatas
def _send_recv_second_metadata(self, first_metadata: List[List[int]]) -> List[List[int]]:
# turn a list of tuple into a tuple of list
recv_second_metadata_buffers, recv_second_metadata_ops = zip(
*(
self.p2p._recv_second_metadata_p2p_op(
shape_length=num_shape, stride_length=num_stride, from_rank=from_rank
)
for (num_shape, num_stride, *_), from_rank in zip(first_metadata, self.recv_from_ranks)
)
)
recv_second_metadata_ops = list(recv_second_metadata_ops)
# Send/Recv second metadata
reqs = dist.batch_isend_irecv(self.second_metadata_p2p_ops + recv_second_metadata_ops)
for req in reqs:
req.wait()
# We want an early cpu/gpu sync here as we are right after the wait so it's nearly free.
# Removing the tolist call here delays the sync and will impact performance.
second_metadatas = [tensor.tolist() for tensor in recv_second_metadata_buffers]
return second_metadatas
def _send_recv_data(self, tensor_metadatas: List[P2PTensorMetaData]) -> List[torch.Tensor]:
# turn a list of tuples into a tuple of list
recv_data_buffers, recv_data_ops = zip(
*(
self.p2p._recv_data_p2p_op(tensor_metadata=tensor_metadata, from_rank=from_rank)
for tensor_metadata, from_rank in zip(tensor_metadatas, self.recv_from_ranks)
)
)
recv_data_ops = list(recv_data_ops)
# Send/Recv tensor data
futures = dist.batch_isend_irecv(self.data_p2p_ops + recv_data_ops)
for future in futures:
future.wait()
# Format tensor by setting the stride
return [
recv_data_buffer.as_strided(size=tuple(tensor_metadata.shape), stride=tuple(tensor_metadata.stride))
for recv_data_buffer, tensor_metadata in zip(recv_data_buffers, tensor_metadatas)
]
def flush(self) -> List[torch.Tensor]:
"""
Run all communication in a batch.
Return `torch.Tensor` in the case of recv.
"""
assert len(self.recv_first_metadata_buffers) == len(
self.recv_from_ranks
), f"len(self.recv_first_metadata_buffers)={len(self.recv_first_metadata_buffers)}, len(self.recv_from_ranks)={len(self.recv_from_ranks)} but should be equal."
# If there is no communication, return
if len(self.first_metadata_p2p_ops) == 0:
return []
# If there is no recv
if len(self.recv_first_metadata_buffers) == 0:
reqs = dist.batch_isend_irecv(
self.first_metadata_p2p_ops + self.second_metadata_p2p_ops + self.data_p2p_ops
)
for req in reqs:
req.wait()
self._reset()
return []
# Send/Recv first metadata
logger.debug(f"First metadata: {[p2pop.op for p2pop in self.first_metadata_p2p_ops]}")
# TODO(kunhao): We could actually send all at once like the above no recv case. But I need to benchmark the performance.
first_metadatas = self._send_recv_first_metadata()
# Send/Recv second metadata
second_metadatas = self._send_recv_second_metadata(first_metadatas)
tensor_metadatas = [
P2PTensorMetaData.from_metadata(first_metadata, second_metadata)
for first_metadata, second_metadata in zip(first_metadatas, second_metadatas)
]
recv_tensors = self._send_recv_data(tensor_metadatas)
# Reset state
self._reset()
return recv_tensors
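# --- Illustrative sketch (not part of the original sources) ------------------
# A hedged example of batching several sends/recvs and flushing them at once.
# It assumes a 2-process launch and an already-constructed `P2P` instance (see
# the sketch after `recv_tensors` above); ranks and shapes are assumptions made
# only for illustration.


def _example_batched_send_recv(p2p: P2P):
    rank = dist.get_rank(p2p.pg)
    batch = BatchTensorSendRecvState(p2p)
    if rank == 0:
        # Queue two sends; nothing is communicated until `flush()`.
        batch.add_send(torch.randn(4, 4, device=p2p.device), to_rank=1)
        batch.add_send(torch.ones(8, device=p2p.device), to_rank=1)
        received = batch.flush()  # a send-only flush returns an empty list
        assert received == []
    else:
        # Queue two recvs; `flush()` returns the reconstructed tensors in order.
        batch.add_recv(from_rank=0)
        batch.add_recv(from_rank=0)
        first, second = batch.flush()
        assert first.shape == (4, 4) and second.shape == (8,)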
import collections
import dataclasses
from abc import ABC, abstractmethod
from typing import List
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.p2p import P2P
logger = logging.get_logger(__name__)
@dataclasses.dataclass
class SendActivation:
activation: torch.Tensor
to_rank: int
p2p: P2P
def __call__(self):
self.p2p.send_tensors([self.activation], to_rank=self.to_rank)
@dataclasses.dataclass
class RecvActivation:
from_rank: int
p2p: P2P
def __call__(self) -> torch.Tensor:
return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]
@dataclasses.dataclass
class SendGrad:
grad: torch.Tensor
to_rank: int
p2p: P2P
def __call__(self):
self.p2p.send_tensors([self.grad], to_rank=self.to_rank)
@dataclasses.dataclass
class RecvGrad:
from_rank: int
p2p: P2P
def __call__(self) -> torch.Tensor:
return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]
class PipelineBatchState(ABC):
activations_buffer = collections.deque()
@abstractmethod
def register_activation_requiring_backward(self, activation: torch.Tensor):
...
@abstractmethod
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
...
@abstractmethod
def register_recv_activation(self, from_rank: int, p2p: P2P):
...
@abstractmethod
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
...
@abstractmethod
def register_recv_grad(self, from_rank: int, p2p: P2P):
...
@abstractmethod
def run_communication(self, send_only_activation: bool = False):
...
@abstractmethod
def new_micro_batch_forward(self):
...
@abstractmethod
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
...
@dataclasses.dataclass
class PipelineTrainBatchState(PipelineBatchState):
microbatches_activations_to_send = collections.deque()
microbatches_activations_to_recv = collections.deque()
microbatches_grads_to_send = collections.deque()
microbatches_grads_to_recv = collections.deque()
grads_buffer = collections.deque()
# List of lists: the first index is the micro_batch_id, the second indexes the activations that need to be popped
microbatches_activations_requiring_backward = collections.deque()
# Reinitialise counter
nb_backwards = 0
nb_forwards = 0
def register_activation_requiring_backward(self, activation: torch.Tensor):
# Register the activation to last microbatch
self.microbatches_activations_requiring_backward[-1].append(activation)
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send activations to higher ranks
self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))
def register_recv_activation(self, from_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv activations from lower ranks
self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send gradients to lower ranks
self.microbatches_grads_to_send.append(SendGrad(grad=grad, to_rank=to_rank, p2p=p2p))
def register_recv_grad(self, from_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv gradients from higher ranks
self.microbatches_grads_to_recv.append(RecvGrad(from_rank=from_rank, p2p=p2p))
def run_communication(self, send_only_activation: bool = False):
"""Run communication in a specific order: send activation, recv activation, send grad, recv grad
Only one communication is done at a time."""
log_rank(
f"activation_to_send: {len(self.microbatches_activations_to_send)} | "
f"activation_to_recv: {len(self.microbatches_activations_to_recv)} | "
f"grads_to_send: {len(self.microbatches_grads_to_send)} | "
f"grads_to_recv: {len(self.microbatches_grads_to_recv)} | "
f"activation_buffer: {len(self.activations_buffer)} | "
f"grads_buffer: {len(self.grads_buffer)}",
logger=logger,
level=logging.DEBUG,
)
# Pop one send activation
if len(self.microbatches_activations_to_send) > 0:
send_activation = self.microbatches_activations_to_send.popleft()
# Execute
activation_send_requires_grad = send_activation.activation.requires_grad
send_activation()
if send_only_activation:
return
# Pop one recv activation
if len(self.microbatches_activations_to_recv) > 0:
recv_activation = self.microbatches_activations_to_recv.popleft()
# Execute
recv_activation_tensor = recv_activation()
self.activations_buffer.append(recv_activation_tensor)
# If somehow you receive a tensor without the need of backward, you shouldn't do cross communication
if recv_activation_tensor.requires_grad is False:
return
# Pop one send gradient
if len(self.microbatches_grads_to_send) > 0:
send_grad = self.microbatches_grads_to_send.popleft()
# Execute
send_grad()
# Pop one recv gradient
if len(self.microbatches_grads_to_recv) > 0:
# Send activation until `activation_send_requires_grad` is True
while len(self.microbatches_activations_to_send) > 0 and not activation_send_requires_grad:
send_activation = self.microbatches_activations_to_send.popleft()
# Execute
activation_send_requires_grad = send_activation.activation.requires_grad
send_activation()
recv_grad = self.microbatches_grads_to_recv.popleft()
# Execute
self.grads_buffer.append(recv_grad())
# TODO @thomasw21: I need some mechanism to point to whatever is now sorted in a buffer, typically some id that would point to the correct tensor in our buffer instead of relying on the sorted list.
def new_micro_batch_forward(self):
self.microbatches_activations_requiring_backward.append(collections.deque())
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
return self.microbatches_activations_requiring_backward.popleft()
def check_buffers_empty(self):
assert (
len(self.microbatches_activations_requiring_backward) == 0
), f"There are still activations that require backward: {len(self.microbatches_activations_requiring_backward)}"
assert (
len(self.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
assert (
len(self.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
assert (
len(self.microbatches_grads_to_send) == 0
), f"There are gradients left for me to send still: {len(self.microbatches_grads_to_send)}"
assert (
len(self.microbatches_grads_to_recv) == 0
), f"There are gradients left for me to recv still: {len(self.microbatches_grads_to_recv)}"
@dataclasses.dataclass
class PipelineEvalBatchState(PipelineBatchState):
microbatches_activations_to_send = collections.deque()
microbatches_activations_to_recv = collections.deque()
activations_buffer = collections.deque()
def register_activation_requiring_backward(self, activation: torch.Tensor):
pass
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))
# There's a cross communication
if len(self.microbatches_activations_to_send) > 0 and len(self.microbatches_activations_to_recv) > 0:
self.run_communication()
def register_recv_activation(self, from_rank: int, p2p: P2P):
self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))
# There's a cross communication
if len(self.microbatches_activations_to_send) > 0 and len(self.microbatches_activations_to_recv) > 0:
self.run_communication()
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
raise NotImplementedError("You can't register a send grad in pipeline eval mode")
def register_recv_grad(self, from_rank: int, p2p: P2P):
raise NotImplementedError("You can't register a recv grad in pipeline eval mode")
def new_micro_batch_forward(self):
pass
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
pass
def run_communication(self, send_only_activation: bool = False):
# four cases:
# - you receive from higher rank and you send to higher rank
# - You receive from higher rank and you send to lower rank
# - you receive from lower rank and you send to higher rank
# - you receive from lower rank and you send to lower rank
send_activation = None
# Pop all send activation
for _ in range(min(1, len(self.microbatches_activations_to_send))):
send_activation = self.microbatches_activations_to_send.popleft()
# Pop all recv activation
recv_activation = None
for _ in range(min(1, len(self.microbatches_activations_to_recv))):
recv_activation = self.microbatches_activations_to_recv.popleft()
if send_activation is None:
if recv_activation is None:
raise ValueError("Why the hell do we communicate when there's nothing to communicate?")
self.activations_buffer.append(recv_activation())
else:
if recv_activation is None:
send_activation()
else:
# Define in which order we do it.
# Actually we can't do any heuristics as you need global information in order to define clear ordering.
# We make a BIG assumption that only ONE rank receives from higher rank and sends to higher rank.
# In this case we find the "lowest" rank, send first
# All the other ranks receive first and send after
# Lowest rank receives.
# If we knew who was involved in the cycle, we could just randomly choose one rank to first send then recv, however it's not clear who's involved
p2p = send_activation.p2p
assert p2p == recv_activation.p2p
is_lowest = send_activation.to_rank > dist.get_rank(
p2p.pg
) and recv_activation.from_rank > dist.get_rank(p2p.pg)
if is_lowest:
send_activation()
self.activations_buffer.append(recv_activation())
else:
self.activations_buffer.append(recv_activation())
send_activation()
def check_buffers_empty(self):
assert (
len(self.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
assert (
len(self.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
assert (
len(self.activations_buffer) == 0
), f"There are activations left in the buffer: {len(self.activations_buffer)}"
import dataclasses
@dataclasses.dataclass
class TensorPointer:
"""Dataclass specifying from which rank we need to query a tensor from in order to access data"""
# Needed to understand from which rank to get the tensor
# TODO @thomasw21: Maybe add which group it belongs to as well? Typically this is highly correlated to `p2p.pg`
group_rank: int
# TODO @thomasw21: Maybe add a tag (torch.distributed.send/recv allow for tagging)
from nanotron.models import NanotronModel
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from torch import nn
from torch.nn.parallel import DistributedDataParallel
def get_input_output_pp_ranks(model: NanotronModel | DistributedDataParallel):
if isinstance(model, DistributedDataParallel):
input_pp_rank = model.module.input_pp_rank
output_pp_rank = model.module.output_pp_rank
else:
input_pp_rank = model.input_pp_rank
output_pp_rank = model.output_pp_rank
return input_pp_rank, output_pp_rank
def get_pp_rank_of(target: str, module: nn.Module):
"""Assuming a model with pipeline blocks, we want to know in which pp rank the module/parameter whose name is `target`"""
if isinstance(module, PipelineBlock):
return module.rank
atoms = target.split(".")
current_module = module
for atom in atoms:
if not hasattr(current_module, atom):
raise AttributeError(f'{current_module._get_name()} has no attribute `"{atom}"`')
current_module = getattr(current_module, atom)
if isinstance(current_module, PipelineBlock):
return current_module.rank
if not isinstance(current_module, nn.Module):
raise AttributeError(f'`"{atom}"` is not an nn.Module')
raise ValueError(f'`"{target}"` is not inside a PipelineBlock and thus does not have a pp_rank')
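# --- Illustrative usage note (not part of the original sources) --------------
# A hedged sketch of how the two helpers above are typically called. The model
# and attribute names (`model`, `lm_head`, `pp_block`) are assumptions made
# only for illustration: `get_pp_rank_of` walks the dotted path and returns the
# rank of the first `PipelineBlock` it encounters, e.g.
#
#     input_rank, output_rank = get_input_output_pp_ranks(model)
#     lm_head_rank = get_pp_rank_of("lm_head.pp_block.weight", module=model)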
import dataclasses
from typing import List, Optional, Tuple
import numpy as np
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel.parameters import NanotronParameter, SlicesPair
@dataclasses.dataclass
class SplitConfig:
split_dim: int
# contiguous_chunks is a tuple of chunk sizes along the split_dim
# sharding happens inside each chunk
# if None, by default contiguous_chunks = (len(unsharded_param.shape[split_dim]),)
contiguous_chunks: Optional[Tuple[int, ...]] = None
def create_sharded_parameter(
parameter: nn.Parameter,
global_ranks: Tuple[int, ...],
local_global_slices_pairs: Tuple[SlicesPair, ...],
unsharded_shape: Tuple[int, ...],
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_sharded(
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
return parameter
def create_sharded_parameter_from_config(
parameter: nn.Parameter,
pg: dist.ProcessGroup,
split_config: SplitConfig,
) -> NanotronParameter:
current_rank = dist.get_rank(pg)
param_num_dims = len(parameter.shape)
global_ranks = dist.get_global_ranks(pg)
split_dim = split_config.split_dim
assert split_dim < param_num_dims
contiguous_chunks = split_config.contiguous_chunks
if contiguous_chunks is None:
# we are assuming that the parameter is contiguous along the split_dim, i.e. 1 whole chunk
# all parameters are equally shardable across the process group along the split_dim
shard_length = parameter.shape[split_dim]
global_slice = slice(current_rank * shard_length, (current_rank + 1) * shard_length)
# construct a mapping from local slices to global slices, multi-dimensional version
local_slices = tuple(slice(None) for _ in range(param_num_dims))
global_slices = tuple(global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims))
local_global_slices_pairs = (SlicesPair(local_slices=local_slices, global_slices=global_slices),)
unsharded_shape = tuple(
pg.size() * param_dim_size if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
else:
# support custom contiguous chunk size for sharding each along the split_dim
local_global_slices_pairs: List[SlicesPair] = []
chunks_global_offset = np.cumsum((0,) + contiguous_chunks)
chunks_local_offset = chunks_global_offset // pg.size()
for chunk, chunk_global_start, chunk_local_start, chunk_local_end in zip(
contiguous_chunks,
chunks_global_offset[:-1],
chunks_local_offset[:-1],
chunks_local_offset[1:],
strict=True,
):
# we assume that we are doing equal split at the chunk level
assert chunk % pg.size() == 0, f"chunk size {chunk} must be divisible by process group size {pg.size()}"
shard_length = chunk // pg.size()
# we have: chunk_local_end = chunk_local_start + shard_length
local_slice = slice(chunk_local_start, chunk_local_end)
global_slice = slice(
current_rank * shard_length + chunk_global_start,
(current_rank + 1) * shard_length + chunk_global_start,
)
local_slices = tuple(
local_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
global_slices = tuple(
global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
local_global_slices_pairs.append(SlicesPair(local_slices=local_slices, global_slices=global_slices))
local_global_slices_pairs: Tuple[SlicesPair, ...] = tuple(local_global_slices_pairs)
unsharded_shape = tuple(
chunks_global_offset[-1] if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
return create_sharded_parameter(
parameter=parameter,
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
def mark_all_parameters_in_module_as_sharded(module: nn.Module, pg: dist.ProcessGroup, split_config: SplitConfig):
"""
Mark parameters as sharded within a module. We assume that parameters are equally shardable across the process group.
:param module: nn.Module
:param pg: dist.ProcessGroup
:param split_config: SplitConfig
:return:
"""
for module_name, submodule in module.named_modules():
for param_name, param in list(submodule.named_parameters(recurse=False)):
new_param = create_sharded_parameter_from_config(parameter=param, pg=pg, split_config=split_config)
setattr(submodule, param_name, new_param)
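# --- Illustrative sketch (not part of the original sources) ------------------
# A hedged, single-process illustration of the slice bookkeeping above for the
# default case (no `contiguous_chunks`): each rank owns
# [rank * shard_length, (rank + 1) * shard_length) of the unsharded parameter
# along `split_dim`. The world size and shapes are assumptions made only for
# illustration.


def _example_default_global_slices():
    world_size = 4
    local_shape = (16, 8)  # shape of the shard held by every rank
    split_dim = 0
    shard_length = local_shape[split_dim]
    unsharded_shape = tuple(
        world_size * size if dim_id == split_dim else size for dim_id, size in enumerate(local_shape)
    )
    assert unsharded_shape == (64, 8)
    covered = []
    for rank in range(world_size):
        global_slice = slice(rank * shard_length, (rank + 1) * shard_length)
        covered.extend(range(global_slice.start, global_slice.stop))
    # The per-rank global slices tile the split dimension exactly once.
    assert covered == list(range(unsharded_shape[split_dim]))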
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
class DifferentiableIdentity(torch.autograd.Function):
"""All-reduce gradients in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
return tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllReduceSum.apply(grad_output, group), None
class DifferentiableAllReduceSum(torch.autograd.Function):
"""All-reduce in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
if group.size() == 1:
return tensor
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
return tensor
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class DifferentiableAllGather(torch.autograd.Function):
"""All gather in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = torch.empty(
unsharded_batch_size,
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group)
return unsharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None
class DifferentiableReduceScatterSum(torch.autograd.Function):
"""Reduce scatter in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: shard along another dimension
unsharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
assert unsharded_batch_size % group.size() == 0
# TODO @thomasw21: Collectives seem to require tensors to be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
tensor = tensor.contiguous()
sharded_tensor = torch.empty(
unsharded_batch_size // group.size(),
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllGather.apply(grad_output, group), None
# -----------------
# Helper functions.
# -----------------
def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableIdentity.apply(tensor, group)
def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllReduceSum.apply(tensor, group)
def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllGather.apply(tensor, group)
def differentiable_reduce_scatter_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableReduceScatterSum.apply(tensor, group)
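# --- Illustrative sketch (not part of the original sources) ------------------
# A hedged example of the forward/backward pairing above under a multi-process
# launch (e.g. `torchrun --nproc_per_node=2 ...`): all-gather in forward is
# paired with reduce-scatter(sum) in backward. The backend and shapes are
# assumptions made only for illustration.


def _example_all_gather_round_trip():
    torch_dist.init_process_group(backend="nccl")
    rank = torch_dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    group = torch_dist.distributed_c10d._get_default_group()

    local = torch.full((2, 3), float(rank), device=device, requires_grad=True)
    gathered = differentiable_all_gather(local, group=group)  # (2 * world_size, 3)
    gathered.sum().backward()
    # The backward reduce-scatters a tensor of ones, so every element of
    # `local` accumulates a gradient equal to the world size.
    assert torch.all(local.grad == torch_dist.get_world_size())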
from enum import Enum, auto
# TODO @thomasw21: python 3.11 introduces `StrEnum` which would've been great to use.
class TensorParallelLinearMode(Enum):
ALL_REDUCE = auto()
REDUCE_SCATTER = auto()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from torch.nn import functional as F
import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
class _ShardedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(
ctx,
sharded_logits, # (batch_size, length, sharded_hidden_size)
target, # (batch_size, length)
group: dist.ProcessGroup,
):
# Maximum value along last dimension across all GPUs.
logits_max = torch.max(sharded_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
# Subtract the maximum value.
sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
# Get the shard's indices
sharded_hidden_size = sharded_logits.shape[-1]
rank = dist.get_rank(group)
start_index = rank * sharded_hidden_size
end_index = start_index + sharded_hidden_size
# Create a mask of valid ids (1 means it needs to be masked).
target_mask = (target < start_index) | (target >= end_index)
masked_target = target.clone() - start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, shard-size] and target to a 1-D tensor of size [*].
logits_2d = sharded_logits.view(-1, sharded_hidden_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
if predicted_logits_1d.is_contiguous():
predicted_logits_1d = predicted_logits_1d.clone()
else:
predicted_logits_1d = predicted_logits_1d.contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = sharded_logits
torch.exp(sharded_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss.view_as(target)
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, sharded_hidden_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None, None
def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtype: torch.dtype = None):
"""Helper function for the cross entropy."""
if dtype is not None:
# Cast input to specific dtype.
sharded_logits = sharded_logits.to(dtype=dtype)
return _ShardedCrossEntropy.apply(sharded_logits, target, group)
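# --- Illustrative sketch (not part of the original sources) ------------------
# A hedged, single-process sanity check of the identity the sharded kernel
# above relies on: cross entropy = log(sum_j exp(logit_j)) - logit_target,
# computed after the usual max subtraction for numerical stability. Shapes are
# assumptions made only for illustration; no process group is involved here.


def _example_cross_entropy_identity():
    torch.manual_seed(0)
    logits = torch.randn(4, 6, 32)  # (batch_size, length, vocab_size)
    target = torch.randint(0, 32, (4, 6))
    # Stabilize exactly like the sharded version: subtract the max logit.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    predicted = shifted.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    loss = torch.log(shifted.exp().sum(dim=-1)) - predicted
    reference = F.cross_entropy(logits.flatten(0, 1), target.flatten(), reduction="none").view_as(target)
    assert torch.allclose(loss, reference, atol=1e-5)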
class _ColumnLinearAsyncCommunication(torch.autograd.Function):
"""Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215"""
@staticmethod
@assert_cuda_max_connections_set_to_1
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
group_size = group.size()
current_rank = dist.get_rank(group)
if group_size == 1:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
else:
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
# ctx.save_for_backward(tensor, weight)
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
# Compute a shard of column_linear in the same time of AllGather
# We could compute the matmul of current holding shard and the current rank's weight
# We assume that rank 0 holds w0, rank 1 holds w1, etc.
# weights: w0 w1 w2 w3
# rank 0: X - - -
# rank 1: - X - -
# rank 2: - - X -
# rank 3: - - - X
# We call the corresponding shard of output "same_device_shard"
output_size = weight.shape[0]
gathered_output = torch.empty(
gathered_batch_size,
*intermediate_size,
output_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
before_shard, same_device_shard, after_shard = torch.split(
gathered_output,
split_size_or_sections=[
sharded_batch_size * current_rank,
sharded_batch_size,
sharded_batch_size * (group_size - current_rank - 1),
],
dim=0,
)
first_dims = math.prod([sharded_batch_size, *intermediate_size])
if bias is None:
torch.mm(
input=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
# Wait communication
handle.wait()
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)
# Compute all the other shards that are obtained from AllGather
# weights: w0 w1 w2 w3
# rank 0: - X X X
# rank 1: X - X X
# rank 2: X X - X
# rank 3: X X X -
# These shards may not be contiguous in the output (e.g. for r1 and r2) since they are separated by "same_device_shard",
# so we compute them separately, i.e. "before_shard" and "after_shard"
# For r0, "before_shard" is empty. For r3, "after_shard" is empty.
if before_shard.numel() > 0:
first_dims = math.prod(before_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
if after_shard.numel() > 0:
first_dims = math.prod(after_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
return gathered_output
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
tp_mode = ctx.tp_mode
handle1: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
if group.size() == 1:
total_tensor = tensor
else:
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = MemoryBuffer().get(
"allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the tensor gradient computation
total_tensor = unsharded_tensor
else:
total_tensor = tensor
grad_tensor = grad_output.matmul(weight)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)
handle2: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
if group.size() == 1:
sub_grad_tensor = grad_tensor
else:
sub_grad_tensor = torch.empty(
ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# Asynchronous all-reduce
handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
else:
raise ValueError()
grad_bias = grad_output.sum(dim=0) if use_bias else None
if handle1 is not None:
handle1.wait()
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = grad_output.t().matmul(total_tensor)
if handle2 is not None:
handle2.wait()
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return sub_grad_tensor, grad_weight, grad_bias, None, None, None
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
return grad_tensor, grad_weight, grad_bias, None, None, None
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""
@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):
# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)
# Get linear output.
out = F.linear(total_input, weight, bias)
return out
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)
# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
sub_grad_input = grad_input
else:
# Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
grad_input = grad_input.contiguous()
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
return sub_grad_input, grad_weight, grad_bias, None, None
def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
input, weight, bias, group, tp_recompute_allgather
)
raise ValueError(f"Got unexpected mode: {tp_mode}.")
class _RowLinearAsyncCommunication(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, weight, bias, group, tp_mode):
assert (
tp_mode is TensorParallelLinearMode.REDUCE_SCATTER
), f"async communication in RowLinear only supports REDUCE_SCATTER, got {tp_mode}"
if group is None:
group = dist.distributed_c10d._get_default_group()
ctx.use_bias = bias is not None
ctx.group = group
out = F.linear(tensor, weight, bias)
if group.size() > 1:
out = differentiable_reduce_scatter_sum(out, group=group)
ctx.save_for_backward(tensor, weight)
return out
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
handle: Optional[dist.Work] = None
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = grad_output.shape
if group.size() == 1:
total_grad_output = grad_output
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
# total_grad_output: [b, s, h_out]
# weight: [h_out, h_in/n]
# total_grad_tensor: [b, s, h_in/n]
# grad_output: [b/n, s, h_out]
sharded_batch_size, *rest_size_grad_output = grad_output.shape
rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]
if group.size() == 1:
total_grad_tensor = grad_output.matmul(weight)
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_tensor = torch.empty(
unsharded_batch_size,
*rest_size_grad_tensor,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
)
before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
total_grad_tensor,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# compute local shard
torch.mm(
input=grad_output.view(-1, grad_output.shape[-1]),
mat2=weight,
out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
)
if handle is not None:
handle.wait()
before_shard_grad_output, _, after_shard_grad_output = torch.split(
total_grad_output,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# before shard compute
if before_shard_grad_tensor.numel() > 0:
torch.mm(
input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
mat2=weight,
out=before_shard_grad_tensor.view(-1, weight.shape[1]),
)
# after shard compute
if after_shard_grad_tensor.numel() > 0:
torch.mm(
input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
mat2=weight,
out=after_shard_grad_tensor.view(-1, weight.shape[1]),
)
# Convert the tensor shapes to 2D for execution compatibility
tensor = tensor.contiguous()
tensor_first_dims, tensor_last_dim = tensor.shape[:-1], tensor.shape[-1]
tensor = tensor.view(math.prod(tensor_first_dims), tensor_last_dim)
# Convert the tensor shapes to 2D for execution compatibility
total_grad_output_first_dims, total_grad_output_last_dim = (
total_grad_output.shape[:-1],
total_grad_output.shape[-1],
)
total_grad_output = total_grad_output.view(math.prod(total_grad_output_first_dims), total_grad_output_last_dim)
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = total_grad_output.t().matmul(tensor)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return total_grad_tensor, grad_weight, grad_bias, None, None
def row_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
):
if async_communication:
return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
out = F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=group)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
return out
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.sharded_parameters import (
SplitConfig,
create_sharded_parameter_from_config,
mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
column_linear,
row_linear,
)
from nanotron.parallel.tied_parameters import create_tied_parameter
class TensorParallelColumnLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
):
self.pg = pg
self.world_size = pg.size()
assert out_features % self.world_size == 0
self.in_features = in_features
self.out_features = out_features // self.world_size
self.tp_recompute_allgather = tp_recompute_allgather
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == out_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to out_features ({out_features})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(
self,
pg=self.pg,
split_config=split_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return column_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}"
class TensorParallelRowLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.world_size = pg.size()
assert in_features % self.world_size == 0
self.in_features = in_features // self.world_size
self.out_features = out_features
# No need to shard the bias term, only rank 0 would have it
bias = dist.get_rank(self.pg) == 0 and bias
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
raise ValueError("async_communication is not supported for ALL_REDUCE mode")
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == in_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to in_features ({in_features})"
split_config = SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks)
self._mark_all_parameters_in_module_as_sharded(split_config)
def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
for name, param in list(self.named_parameters()):
if name == "bias":
# `bias` only exists in rank 0 because it's not sharded
new_param = NanotronParameter(tensor=param)
else:
new_param = create_sharded_parameter_from_config(
parameter=param,
pg=self.pg,
split_config=split_config,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return row_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}"
class TiedLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
):
self.pg = pg
self.world_size = pg.size()
self.mode = mode
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
self._mark_all_parameters_in_module_as_tied()
def _mark_all_parameters_in_module_as_tied(self):
for name, param in list(self.named_parameters()):
new_param = create_tied_parameter(
parameter=param,
name=name,
global_ranks=tuple(sorted((get_global_rank(self.pg, i) for i in range(self.pg.size())))),
reduce_op=None if self.mode is TensorParallelLinearMode.ALL_REDUCE else dist.ReduceOp.SUM,
root_module=self,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = super().forward(x)
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
y = differentiable_identity(y, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
y = differentiable_all_gather(y, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return y
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.rank = dist.get_rank(self.pg)
self.world_size = pg.size()
self.original_num_embeddings = num_embeddings
# TODO @thomasw21: Fix and remove that constraint. Typically there's no reason to have such a constraint.
assert num_embeddings % self.world_size == 0
block_size = num_embeddings // self.world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.rank * block_size
self.max_id = (self.rank + 1) * block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
self.mode = mode
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == num_embeddings
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to num_embeddings ({num_embeddings})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=split_config)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.pg.size() > 1:
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input_ids, input_ids >= self.max_id)
# translate for [0, self.max_id - self.min_id[
masked_input = input_ids.clone() - self.min_id
# default all out of bounds values to `0`
masked_input[input_mask] = 0
else:
masked_input = input_ids
out = super().forward(masked_input)
if self.pg.size() > 1:
out = out * (~input_mask[..., None])
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return out
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_num_embeddings={self.original_num_embeddings}"
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_parameter_and_parent_module
logger = logging.get_logger(__name__)
def create_tied_parameter(
parameter: nn.Parameter,
name: str,
global_ranks: Tuple[int, ...],
reduce_op: Optional[dist.ReduceOp],
root_module: nn.Module,
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_tied(name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module)
return parameter
def tie_parameters(
root_module: nn.Module,
ties: List[Tuple[str, Tuple[int, ...]]],
parallel_context: ParallelContext,
reduce_op: Optional[dist.ReduceOp],
):
"""
Tie parameters.
Within a single device, tied parameters are replaced with a single Parameter
Across devices, we add metadata to Parameters that require extra synchronization.
:param root_module: nn.Module
:param ties: List[Tuple[str, Tuple[int, ...]]]: a tie is (param_target, global_ranks)
:param parallel_context: ParallelContext
:return:
"""
if len(ties) < 1:
raise ValueError("Can't tie nothing")
# TODO @thomasw21: When we support Zero3 this isn't true anymore
dp_ranks = tuple(
sorted(
{
parallel_context.get_local_ranks(world_rank=global_rank)[2]
for _, global_ranks in ties
for global_rank in global_ranks
}
)
)
assert (
len(dp_ranks) == 1
), f"Tying weights has to happen with a replica of a model. Got the ranks from the following replicas: {dp_ranks}"
name = ties[0][0]
global_ranks = tuple(sorted(set().union(*(tie[1] for tie in ties))))
new_param = None
world_rank = dist.get_rank(parallel_context.world_pg)
for tie_target, tie_model_ranks in ties:
if world_rank not in tie_model_ranks:
continue
param, parent_module, param_name = get_parameter_and_parent_module(target=tie_target, root_module=root_module)
# If they are physically in the same device, then we tie them
if new_param is None:
new_param = create_tied_parameter(
parameter=param, name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module
)
# Re-assign it to the original name. We assign the raw tensor instead of the parameter since we moved it already.
setattr(parent_module, param_name, new_param)
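# Illustrative usage sketch (not part of the original source): tying two parameters of a toy
# module that live on the same global rank. The dotted targets and ranks below are hypothetical;
# in practice the targets are paths into `root_module` and the ranks are the global ranks that
# hold a copy of the parameter.
def _example_tie_embedding_and_lm_head(root_module: nn.Module, parallel_context: ParallelContext):
    tie_parameters(
        root_module=root_module,
        ties=[
            ("embedding.weight", (0,)),  # hypothetical parameter target on global rank 0
            ("lm_head.weight", (0,)),  # tied to the same parameter on the same rank
        ],
        parallel_context=parallel_context,
        reduce_op=dist.ReduceOp.SUM,
    )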
def create_pg_for_tied_weights(root_module: nn.Module, parallel_context: ParallelContext):
"""Tied weights are tied across specific set of global ranks, we use this method to create process groups for each difference set of global ranks"""
group_ranks = {
param.get_tied_info().global_ranks
for name, param in root_module.named_parameters()
if isinstance(param, NanotronParameter) and param.is_tied
}
world_group_ranks = [None] * parallel_context.world_pg.size()
dist.all_gather_object(world_group_ranks, group_ranks, group=parallel_context.world_pg)
all_group_ranks = sorted(
set().union(*world_group_ranks),
)
for global_ranks in all_group_ranks:
if global_ranks not in parallel_context.world_ranks_to_pg:
parallel_context.world_ranks_to_pg[global_ranks] = dist.new_group(global_ranks)
def get_tied_id_to_param(
parameters: List[NanotronParameter], root_module: nn.Module
) -> Dict[Tuple[str, Tuple[int, ...]], NanotronParameter]:
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in root_module.named_modules()}
# Fix the root_model
module_id_to_prefix[id(root_module)] = ""
return {
(
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix),
param.get_tied_info().global_ranks, # TODO @nouamane: merge groups which tie the same parameter
): param
for param in parameters
if param.is_tied
}
def sync_tied_weights_gradients(
module: nn.Module, # TODO: NanotronModel
parallel_context: ParallelContext,
grad_accumulator: Optional[GradientAccumulator],
):
tied_id_to_param = get_tied_id_to_param(
parameters=[param for param in module.parameters() if param.requires_grad], root_module=module
)
# Only first and last rank should print the warning
for rank in [0, parallel_context.world_pg.size() - 1]:
log_rank(
f"[Debug Tied Weights] Syncing the following tied weights: {tied_id_to_param.keys()}",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_pg,
rank=rank,
)
# Group tensors to reduce by process groups
# Important to use ordered dict in order to be synchronized across all ranks
group_ranks_and_reduce_op_to_tensors_to_reduce = OrderedDict()
for (name, group_ranks), tied_param in sorted(tied_id_to_param.items(), key=lambda x: x[0]):
tied_info = tied_param.get_tied_info()
# Some weights don't require any syncing, because they are by design synchronised
if tied_info.reduce_op is None:
continue
if grad_accumulator is not None:
tied_grad = grad_accumulator.get_grad_buffer(name=name)
else:
tied_grad = tied_param.grad
log_rank(
f"Syncing tied weights {name} across ranks {group_ranks} ...",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_ranks_to_pg[group_ranks],
rank=0,
)
key = (group_ranks, tied_info.reduce_op)
if key in group_ranks_and_reduce_op_to_tensors_to_reduce:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)].append(tied_grad)
else:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)] = [tied_grad]
for (group_ranks, reduce_op), tensors in group_ranks_and_reduce_op_to_tensors_to_reduce.items():
dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks])
import functools
import operator
import os
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton
class MemoryBuffer(metaclass=Singleton):
"""
    Global memory buffer to store intermediate activations that do not need to be cached for the backward pass.
"""
def __init__(self):
self.buffer = {}
    def get(self, name: str, shape: tuple[int, ...], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
required_numel = functools.reduce(operator.mul, shape, 1)
if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
self.buffer[name, dtype] = torch.empty(
required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[name, dtype][:required_numel].view(shape)
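# Illustrative usage sketch (not part of the original source): the buffer is keyed by
# (name, dtype) and grown on demand, so repeated requests under the same name reuse the same
# storage instead of allocating a fresh tensor every step (requires a CUDA device).
def _example_memory_buffer_reuse() -> bool:
    buf_a = MemoryBuffer().get("allgather", (4, 1024), dtype=torch.bfloat16)
    buf_b = MemoryBuffer().get("allgather", (2, 1024), dtype=torch.bfloat16)
    # Both views start at offset 0 of the same underlying storage (MemoryBuffer is a singleton).
    return buf_a.data_ptr() == buf_b.data_ptr()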
def assert_cuda_max_connections_set_to_1(func):
flag_is_set_to_1 = None
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal flag_is_set_to_1
if flag_is_set_to_1 is None:
assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1"
flag_is_set_to_1 = True
return func(*args, **kwargs)
return wrapper
def initial_sync(model: nn.Module, parallel_context: ParallelContext):
# Synchronize across dp: basic assumption
sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0])
for name, param in sorted_name_params:
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)
# Synchronize across tied weights: basic assumption
for (_, group_ranks), param in sorted(
get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0]
):
group = parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
import contextlib
import random
from dataclasses import dataclass
from typing import MutableMapping, Optional, Tuple
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
@dataclass
class RandomState:
random: Tuple[int, Tuple[int, ...], None]
numpy: Tuple[str, np.ndarray, int, int, float]
torch_cpu: torch.Tensor
torch_cuda: Optional[torch.Tensor]
def __eq__(self, other):
return (
isinstance(other, RandomState)
and all(v1 == v2 for v1, v2 in zip(self.random, other.random))
and all(
np.array_equal(v1, v2) if isinstance(v1, np.ndarray) else v1 == v2
for v1, v2 in zip(self.numpy, other.numpy)
)
and torch.equal(self.torch_cpu, other.torch_cpu)
and (
other.torch_cuda is None if self.torch_cuda is None else torch.equal(self.torch_cuda, other.torch_cuda)
)
)
class RandomStates(MutableMapping[str, RandomState]):
def __init__(self, dict: dict):
for key, value in dict.items():
self.check_type(key, value)
# TODO @thomasw21: We make a copy for safety measure.
self._dict = dict.copy()
@staticmethod
def check_type(key, value):
if not isinstance(key, str):
raise ValueError(f"Expected key to be of type str. Got {type(key)}")
if not isinstance(value, RandomState):
raise ValueError(f"Expected value to be of type `nanotron.dataclass.RandomState`. Got {type(value)}")
def __getitem__(self, item):
return self._dict[item]
def __iter__(self):
return self._dict.__iter__()
def __len__(self):
return len(self._dict)
def __delitem__(self, key):
raise ValueError("Can't delete a random states key")
def __setitem__(self, key, value):
if key not in self._dict:
raise ValueError("Can't add a new random states after initialisation")
self.check_type(key, value)
return self._dict.__setitem__(key, value)
def __eq__(self, other):
if not isinstance(other, RandomStates):
return False
return self._dict == other._dict
def set_random_seed(seed: int):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def set_random_state(random_state: RandomState):
random.setstate(random_state.random)
np.random.set_state(random_state.numpy)
torch.set_rng_state(random_state.torch_cpu)
if torch.cuda.is_available():
torch.cuda.set_rng_state(random_state.torch_cuda, "cuda")
else:
assert random_state.torch_cuda is None
def get_current_random_state():
"""Returns a snapshot of current random state"""
return RandomState(
random=random.getstate(),
numpy=np.random.get_state(),
torch_cpu=torch.random.get_rng_state(),
torch_cuda=torch.cuda.get_rng_state("cuda") if torch.cuda.is_available() else None,
)
@contextlib.contextmanager
def branch_random_state(random_states: RandomStates, key: str, enabled: bool):
"""
Context manager handling random state:
- upon entering: Stores current random state and set new random state defined by key.
- upon exiting: updates key in `random_states` to the new current random state, and set back the old one.
"""
if not enabled:
yield
return
old_random_state = get_current_random_state()
# Get the new state associated to the key
new_random_state = random_states[key]
set_random_state(new_random_state)
try:
yield
finally:
# Update state from parallel_context with the newest state
new_random_state = get_current_random_state()
random_states[key] = new_random_state
# Set the old state back
set_random_state(old_random_state)
def get_synced_random_state(
random_state: RandomState,
pg: ProcessGroup,
):
# We use rank 0 as a reference and broadcast random states from that rank to all the other ranks within a group in order to sync them
reference_rank = 0
if dist.get_rank(pg) == reference_rank:
random_states = [random_state]
else:
random_states = [None]
# TODO @thomasw21: broadcast tensor using `broadcast` in order not to use pickle
dist.broadcast_object_list(
random_states, src=dist.get_global_rank(pg, reference_rank), group=pg, device=torch.device("cuda")
)
new_random_state = random_states[0]
assert new_random_state is not None
return new_random_state
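# Illustrative usage sketch (not part of the original source): building a random-state entry
# that is kept in sync across a process group and branching into it around an op that must
# produce identical randomness on every rank (e.g. dropout on replicated activations).
# `tp_pg` is assumed to be an initialized process group and CUDA to be available.
def _example_synced_dropout(x: torch.Tensor, tp_pg: ProcessGroup) -> torch.Tensor:
    random_states = RandomStates(
        {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
    )
    with branch_random_state(random_states, key="tp_synced", enabled=True):
        # Inside the context every rank of `tp_pg` draws from the same RNG state.
        return torch.nn.functional.dropout(x, p=0.1, training=True)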
from .fsspec import check_path_is_local, fs_copy, fs_open
from .s3_mover import S3Mover
__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
import contextlib
from pathlib import Path
from typing import Tuple, Union
import fsspec
from fsspec.implementations import local
def get_filesystem_and_path(path: Path, storage_options=None) -> Tuple[fsspec.AbstractFileSystem, str]:
# Use supported filesystems in `fsspec`. If you need another one, please use `fsspec.registry.register_implementation`
# DO NOT USE `mode` argument as it adds a suffix `0.part` when using `mode="w"`.
fs, _, paths = fsspec.core.get_fs_token_paths(str(path), storage_options=storage_options)
assert len(paths) == 1
return fs, paths[0]
@contextlib.contextmanager
def fs_open(
file: Union[str, Path],
mode="r",
):
# TODO @thomasw21: pass storage options.
fs, path = get_filesystem_and_path(file)
with fs.open(path, mode=mode) as f:
yield f
def fs_copy(
input_file: Union[str, Path],
output_file: Union[str, Path],
):
"""Copy file from input to output (possibly on s3/other fs)"""
with fs_open(input_file, mode="rb") as fi, fs_open(output_file, mode="wb") as fo:
fo.write(fi.read())
def check_path_is_local(path: Path, storage_options=None) -> bool:
return isinstance(get_filesystem_and_path(path=path, storage_options=storage_options)[0], local.LocalFileSystem)
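# Illustrative usage sketch (not part of the original source): the helpers behave the same way
# for local paths and for fsspec-backed remote paths (e.g. s3://...). The paths are examples.
def _example_fs_helpers() -> bool:
    with fs_open("/tmp/example.txt", mode="w") as f:  # hypothetical local path
        f.write("hello")
    fs_copy("/tmp/example.txt", "/tmp/example_copy.txt")
    return check_path_is_local(Path("/tmp/example_copy.txt"))  # True for a local filesystem path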
import glob
import json
import os
import subprocess
import time
from datetime import datetime
from enum import Enum
from typing import Optional, Tuple, Union
import torch
from datasets.download.streaming_download_manager import xPath
from filelock import FileLock, Timeout
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import human_format
logger = logging.get_logger(__name__)
class S3Mover:
    # TODO @eliebak: update the doc to state that this class is also used to download checkpoints to disk via `start_downloading`
"""Take care of uploading a checkpoint to S3 in the background and remove it from the disk.
Args:
local_path: Path to the checkpoints on the local disk
s3_path: Path to the checkpoints on S3
remove_after_upload: If True, remove the checkpoint from the disk after uploading it to S3
s5cmd_numworkers: Number of workers to use for the s5cmd command
s5cmd_concurrency: Concurrency to use for the s5cmd command
s5cmd_path: Path to the s5cmd command
        dummy: If True, don't actually upload/remove anything. Useful in multi-process setups where only one process per node should perform the upload.
Usage:
        # Create a mover - use dummy=True for all the processes that shouldn't do anything (e.g. all but one per node)
mover = S3Mover(local_path=/scratch/my-checkpoints,
s3_path=s3://my-bucket/my-checkpoints,
remove_after_upload=True,
s5cmd_numworkers=96,
s5cmd_concurrency=10,
s5cmd_path=/admin/user/my/bin/s5cmd,
dummy=False)
while training:
            # from time to time, update the state
mover_status = mover.update()
...
# When saving a checkpoint, check if the previous checkpoint has been uploaded and removed
# in a distributed setting
"""
class S3MoverState(Enum):
IDLE = "IDLE"
UPLOADING = "UPLOADING"
DOWNLOADING = "DOWNLOADING"
REMOVING_CHECKPOINT = "REMOVING_CHECKPOINT"
class DummyPopen:
def __init__(self, *args, **kwargs):
pass
def poll(self):
return 0
def communicate(self):
return ("", "")
def __init__(
self,
local_path: xPath,
s3_path: xPath,
post_upload_callback: Optional[callable] = None,
remove_after_upload: Optional[bool] = True,
s5cmd_numworkers: Optional[int] = None,
s5cmd_concurrency: Optional[int] = None,
s5cmd_path: Optional[str] = None,
s5cmd_credentials: Optional[str] = None,
clean_up_local_on_start: bool = False,
dummy: bool = False,
s3_region: str = "us-east-1",
):
self.process: Optional[Union[subprocess.Popen, S3Mover.DummyPopen]] = None
self.remove_after_upload = remove_after_upload
self.s5cmd_numworkers = s5cmd_numworkers
self.s5cmd_concurrency = s5cmd_concurrency
self.s5cmd_path = s5cmd_path if s5cmd_path is not None else "s5cmd"
self.s5cmd_credentials = s5cmd_credentials
self.lock_file = None
self.dummy = dummy
self.s3_region = s3_region
self.post_upload_callback = post_upload_callback
self.post_upload_callback_outputs = None
local_path = str(local_path)
if not local_path.startswith("/scratch/"):
self._warning(f"The local path is not on the scratch drive: {local_path}")
if not local_path.endswith("/"):
local_path += "/"
s3_path = str(s3_path)
if not s3_path.endswith("/"):
s3_path += "/"
self.local_path = local_path
self.s3_path = s3_path
s3_bucket, s3_prefix = s3_path.replace("s3://", "").split("/", maxsplit=1)
self.s3_path_direct_link = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?region={self.s3_region}&prefix={s3_prefix}&showversions=false"
self._reset_state()
if clean_up_local_on_start:
self._start_removing()
def _warning(self, message):
if self.dummy:
return
logger.warning(message)
def _info(self, message):
if self.dummy:
return
logger.info(message)
def _reset_state(self):
self.state = self.S3MoverState.IDLE
self.num_uploaded_files = 0
if self.lock_file is not None:
self._release_lock()
self.lock_file = None
self.stdout = ""
        self.start_time: Optional[datetime] = None
self.cmd = ""
def _popen(self, cmd: list):
self.stdout = ""
self.start_time = datetime.now()
self.cmd = cmd
if self.dummy:
return self.DummyPopen(cmd)
else:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
os.set_blocking(process.stdout.fileno(), False)
return process
def _acquire_lock(self, file_path: str) -> bool:
if self.dummy:
return True
if file_path.endswith("/"):
lock_file_path = file_path[:-1] + ".lock"
else:
lock_file_path = file_path + ".lock"
self.lock_file = FileLock(lock_file_path)
try:
self.lock_file.acquire(timeout=1)
except Timeout:
message = f"[S3] The checkpoint files {lock_file_path} are currently locked by another process. "
self._warning(message)
return False
return True
def get_state_as_int(self) -> int:
"""Return the state as an int"""
if self.state == self.S3MoverState.IDLE:
return 0
elif self.state == self.S3MoverState.UPLOADING:
return 1
elif self.state == self.S3MoverState.DOWNLOADING:
return 2
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
return 3
else:
return -1
def _release_lock(self):
if self.dummy:
return
if self.lock_file is not None and self.lock_file.is_locked:
self.lock_file.release()
def get_current_stdout(self) -> str:
"""Return the current stdout of the process if any"""
if self.process is None or isinstance(self.process, self.DummyPopen):
return ""
try:
stdout = self.process.stdout.read()
except ValueError:
stdout = "" # The buffer is already closed: "ValueError: read of closed file"
if stdout:
self.stdout += stdout.decode()
return self.stdout
def wait_for_completion(self):
while self.state != self.S3MoverState.IDLE:
_ = self.update()
time.sleep(0.5)
def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
"""Wait for the previous checkpoint to be fully uploaded and removed in a distributed setting.
        Waits for all processes to be ready.
"""
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
if all_saved != group.size() and self.state != self.S3MoverState.IDLE:
self._warning(
f"Waiting previous checkpoint saving is finished - S3Mover {dist.get_rank(group)} still in {self.state} state.",
)
while all_saved != group.size():
stdout = self.get_current_stdout()
stdout_lines = [lst for lst in stdout.split("\n") if lst]
if self.state != self.S3MoverState.IDLE:
self._warning(
f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}",
)
            # Sync all saves over NCCL. We could do a dist barrier later, but this helps us avoid losing NCCL connections down the line.
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
time.sleep(1)
def is_previous_save_finished(self) -> bool:
"""Return True if a potential previous checkpoint has been fully uploaded to S3
and removed from the drive
"""
self.update()
return self.state == self.S3MoverState.IDLE
    def _start_downloading(self, sub_folder: Optional[str] = None) -> bool:
self._warning(
f"[S3] Downloading checkpoint in background from {self.s3_path} to {self.local_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.s3_path + "*", self.local_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.DOWNLOADING
return True
def _post_downloading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
total_files = len([i for i in s5cmd_results if i["success"]])
total_not_downloaded_files = len([i for i in s5cmd_results if not i["success"]])
if total_not_downloaded_files == 0:
all_upload = "all files"
success = True
else:
all_upload = "not all files"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully downloaded {total_files} files for a total of {human_format(total_size)}B in {total_time}"
f"sec ({all_upload}) from S3 at {self.s3_path} to {self.local_path}"
f"(direct link: {self.s3_path_direct_link})"
)
return success
    def _start_uploading(self) -> bool:
        if not os.path.exists(self.full_local_path):
            message = f"[S3] Checkpoint {self.full_local_path} does not exist, cannot upload to S3"
            self._warning(message)
            return False
        # Get a file lock on the first file
        local_files = glob.glob(self.full_local_path + "/**/*.*", recursive=True)
        locked = self._acquire_lock(local_files[0])
        if not locked:
            return False
self._warning(
f"[S3] Uploading checkpoint in background from {self.full_local_path} to {self.full_s3_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp", "--exclude", "*.lock", "--exclude", "*.lock.*"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.full_local_path, self.full_s3_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.UPLOADING
return True
def _post_uploading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
local_files = glob.glob(self.full_local_path + "/**/*.?*", recursive=True)
total_files = len([i for i in s5cmd_results if i["success"]])
self.num_uploaded_files = total_files
if len(local_files) == total_files:
all_upload = "all files"
success = True
else:
all_upload = f"not all files: {len(local_files)} out of {total_files}"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully uploaded {total_files} files for a total of {human_format(total_size)}B in {total_time} sec"
f"({all_upload}) from {self.full_local_path} to S3 at {self.full_s3_path} "
f"(direct link: {self.s3_path_direct_link})"
)
if self.post_upload_callback:
self.post_upload_callback_outputs = self.post_upload_callback(uploaded_files=s5cmd_results)
self._release_lock()
return success
    def _start_removing(self) -> bool:
top_dir_in_local_checkpoint = [dir for dir in glob.glob(self.local_path + "/*") if os.path.isdir(dir)]
names_dir = [os.path.basename(dir) for dir in top_dir_in_local_checkpoint]
if len(names_dir) == 0:
            # If the local path is already empty, or if another process has already started removing it, skip with a no-op
            self._warning("[S3] Local checkpoint is empty, skipping removal")
cmd = ["echo", "'skipping'"]
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
self._warning(f"[S3] Removing checkpoint in background: {names_dir}")
locked = self._acquire_lock(top_dir_in_local_checkpoint[0])
if not locked:
return False
cmd = ["rm", "-rfv"] + top_dir_in_local_checkpoint
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
def _post_removing(self) -> bool:
self.get_current_stdout()
local_files = [
loc_f
for loc_f in self.stdout.split("\n")
if "directory" not in loc_f.lower() and loc_f and ".lock" not in loc_f
]
if len(local_files) == self.num_uploaded_files:
all_removed = "all files"
success = True
else:
all_removed = "not all files"
success = False
self._release_lock()
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully removed {len(local_files)} local files ({all_removed}) from {self.local_path} (uploaded to {self.s3_path_direct_link}) in {total_time}"
)
return success
    def update(self) -> Tuple[str, str]:
        """Update the state of the mover: UPLOADING => REMOVING_CHECKPOINT => IDLE (or DOWNLOADING => IDLE)
Returns:
(str, str): The state and the stdout of the process if any
"""
if self.process is None:
self._reset_state()
            return self.state.value, self.stdout
return_code = self.process.poll()
if return_code is None:
# Still running
            return self.state.value, self.stdout
if return_code != 0:
self.get_current_stdout()
self._warning(
f"[S3] Error running command {self.cmd} during process {self.state.value}, "
f"return code {return_code}, return message {self.stdout}"
)
            return self.state.value, self.stdout
if self.state == self.S3MoverState.DOWNLOADING:
self._post_downloading()
self._reset_state()
elif self.state == self.S3MoverState.UPLOADING:
self._post_uploading()
if self.remove_after_upload:
self._start_removing()
else:
self._reset_state()
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
self._post_removing()
self._reset_state()
return self.state.value, self.stdout
def start_uploading(self, sub_folder=None):
"""Start uploading last saved checkpoint to S3 in the background.
        After running this method, you should call `update` regularly to advance the
        state through removal (if enabled) and back to idle.
For a blocking upload, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = "[S3] Cannot move to S3 as the previous checkpoint has not been uploaded and removed"
self._warning(message)
return False
self.full_local_path = self.local_path + (f"/{sub_folder}" if sub_folder else "")
self.full_s3_path = self.s3_path + (f"/{sub_folder}" if sub_folder else "")
return self._start_uploading()
def start_downloading(self):
"""Start downloading a checkpoint from S3 in the background.
        After running this method, you should call `update` regularly to refresh the
state.
For a blocking download, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = f"[S3] Cannot download from S3 as the state is not IDLE but {self.state.value}"
self._warning(message)
return False
return self._start_downloading()
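# Illustrative usage sketch (not part of the original source): a blocking upload driven from
# one process per node. `checkpoint_name` and `group` are hypothetical placeholders.
def _example_blocking_upload(mover: S3Mover, checkpoint_name: str, group: ProcessGroup):
    # Make sure any previous transfer has finished on every rank before starting a new one.
    mover.distributed_wait_for_completion(group=group)
    if mover.start_uploading(sub_folder=checkpoint_name):
        # Poll `update()` until the upload (and removal, if enabled) is done.
        mover.wait_for_completion()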
from contextlib import contextmanager
from typing import Callable, Optional
import torch
from nanotron import distributed as dist
from nanotron import logging, optim
from nanotron.config import Config
from nanotron.logging import get_logger, log_rank
from nanotron.models import NanotronModel
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
logger = get_logger(__name__)
def assert_tensor_synced_across_pg(
tensor: torch.Tensor,
pg: dist.ProcessGroup,
msg: Optional[Callable[[str], str]] = None,
reference_rank: int = 0,
):
"""Assert that `tensor` is synced across `pg` with reference rank. Note that this always passes for reference rank"""
if dist.get_rank(pg) == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=dist.get_global_rank(group=pg, group_rank=reference_rank),
group=pg,
)
# TODO @nouamane: Getting Greatest absolute difference: 4.6e-10 at large scale when syncing tied weights
torch.testing.assert_close(tensor, reference_tensor, msg=msg)
# TODO @nouamanetazi: remove this with SANITY_CHECKS
@contextmanager
def assert_fail_except_rank_with(exception_class, rank_exception, pg):
try:
yield
except exception_class:
if rank_exception == dist.get_rank(pg):
raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.")
else:
return
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got {type(e)} instead:\n{e}")
if dist.get_rank(pg) != rank_exception:
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
def before_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(
parameters=unwrapped_model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that model grads are zeroed or None
for name, param in unwrapped_model.named_parameters():
if param.grad is not None:
torch.testing.assert_close(
param.grad,
torch.zeros_like(param.grad),
atol=0,
rtol=0,
msg="Model half precision grads must be zeroed or None in first accumulation step.",
)
# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
fp32_grad_buffer = elt["fp32_grad"]
torch.testing.assert_close(
fp32_grad_buffer,
torch.zeros_like(fp32_grad_buffer),
atol=0,
rtol=0,
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)
# TODO: add checks for memory contiguousness
# SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
for i, group in enumerate(lr_scheduler.optimizer.param_groups):
assert (
group["lr"] == lr_scheduler.get_last_lr()[i]
), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
break
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()
def after_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that gradient flow on the entire model
# SANITY CHECK: Check that all parameters that required gradients, have actually a gradient
# SANITY CHECK: Check for nan/inf
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
            if grad is None:
                log_rank(
                    f"Process rank {dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} is missing gradient",
                    logger=logger,
                    level=logging.ERROR,
                )
            elif torch.isnan(grad).any() or torch.isinf(grad).any():
                raise ValueError(f"Gradient of {name} is nan or inf")
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_tbi_sanity_checks()
def before_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
optimizer: optim.BaseOptimizer,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Test tied weights gradients are synchronized
for (name, group_ranks), param in sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
):
if not param.requires_grad:
continue
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=grad,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights grads for {name} are not synchronized. {err}",
)
# SANITY CHECK: Test gradients are synchronized across DP
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
assert_tensor_synced_across_pg(
tensor=grad,
pg=parallel_context.dp_pg,
msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP. {err}",
)
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that optimizer states are synchronized across DP
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()
def after_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that gradients is cleared
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.grad is not None:
log_rank(
f"Process rank { dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} still has gradient despite having ran the optimizer",
logger=logger,
level=logging.ERROR,
)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_optim_step_sanity_checks()
def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup):
for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
if name == "step":
continue
assert_tensor_synced_across_pg(
tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}"
)
import math
from abc import abstractmethod
from enum import Enum, auto
from typing import Dict
from nanotron.config import ModelArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from torch import nn
from torch.nn import init
class ParametrizationMethod(Enum):
STANDARD = auto()
SPECTRAL_MUP = auto()
class Parametrizator:
def __init__(self, config: ModelArgs):
self.config = config
def parametrize(self, param_name: str, module: nn.Module):
if not isinstance(module, tuple(self.MODULE_TO_PARAMETRIZE.keys())):
raise Exception(f"Parameter {param_name} was not initialized")
return self.MODULE_TO_PARAMETRIZE[type(module)](param_name, module)
class StandardParametrizator(Parametrizator):
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_column_linear,
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers
def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_row_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
std = self.std / math.sqrt(2 * self.num_layers)
init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
class SpectralMupParametrizator(Parametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
https://arxiv.org/abs/2310.17813
"""
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_mup_weight,
TensorParallelRowLinear: self._parametrize_mup_weight,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = 1.0
@staticmethod
def _compute_spectral_std(std: float, fan_in: int, fan_out: int):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
σₗ = Θ(1/√nₗ₋₁ min{1, √(nₗ/nₗ₋₁)})
"""
return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
def _parametrize_mup_weight(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
data = module.weight if param_name == "weight" else module.bias
fan_in, fan_out = init._calculate_fan_in_and_fan_out(data)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
std = SpectralMupParametrizator._compute_spectral_std(std=self.std, fan_in=fan_in, fan_out=fan_out)
init.normal_(data, mean=0.0, std=std)
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
# NOTE: you're free to change the initialization of layer norm
# as it's not a part of µTransfer
if "weight" == param_name:
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
# NOTE: you're free to change the initialization of input embedding/lm head
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
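# Illustrative sketch (not part of the original source): what the spectral rule above gives for
# an example shape. For a 1024 -> 4096 projection with std = 1.0:
# sigma = (1 / sqrt(1024)) * min(1, sqrt(4096 / 1024)) = 1 / 32 = 0.03125.
def _example_spectral_std() -> float:
    return SpectralMupParametrizator._compute_spectral_std(std=1.0, fan_in=1024, fan_out=4096)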
class LearningRateForParametrizator:
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
self.lr = lr
self.names_to_modules = names_to_modules
@abstractmethod
def get_lr(self, param_name: str, module: nn.Module) -> float:
raise NotImplementedError
class LearningRateForSP(LearningRateForParametrizator):
"""All parameters get the same learning rate."""
def get_lr(self, param_name: str, param: nn.Module) -> float:
return self.lr
class LearningRateForSpectralMup(LearningRateForParametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
NOTE: each parameter gets a custom learning rate based on its fan-in and fan-out.
"""
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
super().__init__(lr, names_to_modules)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._get_mup_lr,
TensorParallelRowLinear: self._get_mup_lr,
TritonRMSNorm: self._get_global_lr,
TensorParallelEmbedding: self._get_global_lr,
}
def _get_mup_lr(self, param: nn.Parameter, module: nn.Module):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
ηₗ = Θ(nₗ/nₗ₋₁)
"""
fan_in, fan_out = init._calculate_fan_in_and_fan_out(param)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
return self.lr * (fan_out / fan_in)
def _get_global_lr(self, param: nn.Parameter, module: nn.Module) -> float:
return self.lr
def get_lr(self, param_name: str, param: nn.Parameter) -> float:
"""Return the learning rate for the given parameter."""
# NOTE: param_name should be like 'model.token_position_embeddings.pp_block.token_embedding.weight'
# since names_to_modules map module_name to module
# so we remove the .weight and .bias from param_name to get the module_name
module_name = param_name.rsplit(".", 1)[0]
module = self.names_to_modules[module_name]
return self.MODULE_TO_PARAMETRIZE[type(module)](param, module)
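# Illustrative usage sketch (not part of the original source): building per-parameter learning
# rates for an optimizer from `LearningRateForSpectralMup`. Assumes every parameterized leaf
# module of `model` is one of the module types handled in MODULE_TO_PARAMETRIZE above.
def _example_mup_param_groups(model: nn.Module, lr: float):
    lr_mapper = LearningRateForSpectralMup(lr=lr, names_to_modules=dict(model.named_modules()))
    return [
        {"params": [param], "lr": lr_mapper.get_lr(name, param)}
        for name, param in model.named_parameters()
    ]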
# flake8: noqa
from nanotron.serialize.main import *
from nanotron.serialize.optimizer import *
from nanotron.serialize.random import *
from nanotron.serialize.weights import *