Commit 61e92904 authored by chenzk

v1.0
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel.pipeline_parallel.functional import (
recv_from_pipeline_state_buffer,
send_to_pipeline_state_buffer,
)
from nanotron.parallel.pipeline_parallel.p2p import P2P, BatchTensorSendRecvState
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState, PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
class PipelineBlock(nn.Module):
"""Most granular pipeline block, ie within this module, everything will be part of a single rank, ie the entire computation within this block will happen on a specific device.
Current limitations:
- PipelineBlocks have to wrap a method/function/module that outputs a Dict[str, torch.Tensor]
Some considerations:
- In the literature, authors often refer to pipeline stages as a granularity block. Our notion is more granular. A pipeline stage is list of contiguous (in the forward sense) of pipeline blocks.
All PipelineBlock definition exist in each rank, they are just instantiated/built on a single rank per pipeline parallel process group.
"""
def __init__(
self,
p2p: P2P,
module_builder: Callable[..., Callable[..., Union[torch.Tensor, Dict[str, torch.Tensor]]]],
module_kwargs: Dict[str, Any],
module_input_keys: Set[str],
module_output_keys: Set[str],
):
super().__init__()
        # Module follows a restrictive API: module.forward returns a `Dict[str, torch.Tensor]`
        self.p2p = p2p
        # None signifies that we don't use a specific pipeline engine and just run the typical torch forward/backward pass
self.pipeline_state: Optional[PipelineBatchState] = None
self.module_builder = module_builder
self.module_kwargs = module_kwargs
self.module_input_keys = set(module_input_keys)
self.module_output_keys = set(module_output_keys)
def build_and_set_rank(self, pp_rank: int):
"""This method is used to define on which rank computation is going to happen"""
assert pp_rank < self.p2p.pg.size()
self.rank = pp_rank
if pp_rank == dist.get_rank(self.p2p.pg):
# Instantiate the module
self.pp_block = self.module_builder(**self.module_kwargs)
def extra_repr(self) -> str:
return f"pp_rank={self.rank}" if hasattr(self, "rank") else ""
def set_pipeline_state(self, pipeline_state: Optional[PipelineBatchState]):
self.pipeline_state = pipeline_state
def forward(self, **kwargs):
"""Forward pass
We use a mechanism using TensorPointers to pass Tensors around
All non Tensor object or TensorPointers are considered pass-through, they are never meant to be communicated cross process
:param kwargs: Dict[str, Union[TensorPointer, torch.Tensor, Any]]
:return: Dict[str, Union[TensorPointer, torch.Tensor, Any]
"""
assert self.module_input_keys == set(
kwargs.keys()
), f"Expected {self.module_input_keys}, got {set(kwargs.keys())}"
sorted_kwargs = sorted(kwargs.items(), key=get_sort_key(dist.get_rank(self.p2p.pg)))
        # If the current rank is not the one running the compute
if dist.get_rank(self.p2p.pg) != self.rank:
# TODO(kunhao): A better design is to pop this up for both if else branches.
batch_send_recv = BatchTensorSendRecvState(self.p2p)
            # Send the local activations to the rank that actually runs the compute
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank is neither the rank holding the data nor the rank responsible for computing block
continue
else:
assert isinstance(tensor, torch.Tensor)
# We need to send the tensor to the rank that actually runs the compute
if self.pipeline_state is not None:
send_to_pipeline_state_buffer(
tensor,
to_rank=self.rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
if tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
batch_send_recv.add_send(tensor=tensor, to_rank=self.rank)
batch_send_recv.flush()
            # Signal that the outputs all live on the rank responsible for computing this block
# TODO @thomasw21: Figure out a way to build dummy_input in a generic sense, and remove the necessity to have Dict[str, torch.Tensor] as output
return {k: TensorPointer(group_rank=self.rank) for k in self.module_output_keys}
# Recv activations from other devices to local rank
new_kwargs: Dict[str, torch.Tensor] = {}
name_to_recv_id = {}
batch_send_recv = BatchTensorSendRecvState(self.p2p)
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank is the one running the compute, we need to query the tensor
# new_kwargs[name] = recv_tensor(from_rank=tensor.group_rank, p2p=self.p2p)
# This assumes that prior communication was already done
# In case of interleaved 1f1b, if this is the second model chunk, then we need to send the previous activations before receiving the current activations
if isinstance(self.pipeline_state, PipelineTrainBatchState):
for _ in range(len(self.pipeline_state.microbatches_activations_to_send)):
send_activation = self.pipeline_state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
if self.pipeline_state is not None:
new_kwargs[name] = recv_from_pipeline_state_buffer(
from_rank=tensor.group_rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
# We don't store result in a buffer
recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
name_to_recv_id[name] = recv_id
else:
new_kwargs[name] = tensor
# Run receiving communications
recv_tensors = batch_send_recv.flush()
assert len(recv_tensors) == len(name_to_recv_id)
for name, recv_id in name_to_recv_id.items():
assert name not in new_kwargs
new_tensor = recv_tensors[recv_id]
if new_tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried receiving a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
new_kwargs[name] = new_tensor
output = self.pp_block(**new_kwargs)
# Helper for functions that return tensors
if isinstance(output, torch.Tensor):
assert len(self.module_output_keys) == 1
output = {next(iter(self.module_output_keys)): output}
assert isinstance(output, dict), "Modules within a Pipeline Block have to return a Dict[str, torch.Tensor]"
assert self.module_output_keys == set(
output.keys()
), f"Expected {self.module_output_keys}, got {set(output.keys())}"
return output
def get_min_max_rank(module: torch.nn.Module) -> Tuple[int, int]:
"""Finds min and max PP ranks of the underlying PipelineBlocks"""
ranks = [module.rank for module in module.modules() if isinstance(module, PipelineBlock)]
return min(ranks), max(ranks)
def get_sort_key(current_rank: int):
"""The idea is to free earlier ranks earlier."""
def sort_key(elt: Tuple[str, Union[torch.Tensor, TensorPointer]]):
name, tensor = elt
rank: int
if isinstance(tensor, TensorPointer):
rank = tensor.group_rank
else:
rank = current_rank
return rank, name
return sort_key
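# A minimal two-rank usage sketch of PipelineBlock (illustrative only). Assumptions:
# torch.distributed is already initialized, `pg` is a 2-rank pipeline-parallel process
# group, `device` is the local device, rank 0 holds the input data and rank 1 owns the
# block. No pipeline engine is attached, so the input must not require grad. The
# function name and the layer sizes are made up for the example.
def _example_pipeline_block(pg: dist.ProcessGroup, device: torch.device) -> Dict[str, Any]:
    p2p = P2P(pg, device=device)
    block = PipelineBlock(
        p2p=p2p,
        module_builder=lambda: nn.Linear(8, 4, device=device),
        module_kwargs={},
        module_input_keys={"input"},  # must match the wrapped module's forward kwargs
        module_output_keys={"output"},  # single key: a bare tensor output gets wrapped into a dict
    )
    block.build_and_set_rank(pp_rank=1)  # the nn.Linear is only instantiated on rank 1
    if dist.get_rank(pg) == 0:
        x = torch.randn(2, 8, device=device)  # the real data lives on rank 0
    else:
        x = TensorPointer(group_rank=0)  # "the data lives on rank 0"
    out = block(input=x)
    # On rank 1, out["output"] is a torch.Tensor; on every other rank it is a
    # TensorPointer(group_rank=1) pointing at the compute rank.
    return out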
from contextlib import contextmanager
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState
from torch import nn as torch_nn
@contextmanager
def attach_pipeline_state_to_model(model: torch_nn.Module, pipeline_state: PipelineBatchState):
"""Attach the pipeline state to all the PipelineBlocks within `model`"""
old_pipeline_states = []
# Set new
for name, module in model.named_modules():
if not isinstance(module, PipelineBlock):
continue
old_pipeline_state = module.pipeline_state
assert old_pipeline_state is None, "We never replace an old pipeline engine, we just set one when there's none"
old_pipeline_states.append((old_pipeline_state, module))
module.set_pipeline_state(pipeline_state)
try:
yield
finally:
for old_pipeline_state, module in old_pipeline_states:
module.set_pipeline_state(old_pipeline_state)
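# A minimal usage sketch of the context manager (illustrative only). It assumes `model`
# contains PipelineBlocks and that `micro_batch` matches the model's forward signature.
# PipelineTrainBatchState is the state the training engines use; outside the `with`
# block the PipelineBlocks fall back to the plain torch forward path.
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState


def _example_attach_pipeline_state(model: torch_nn.Module, micro_batch: dict):
    state = PipelineTrainBatchState()
    with attach_pipeline_state_to_model(model=model, pipeline_state=state):
        state.new_micro_batch_forward()  # the engines call this once per micro-batch
        output = model(**micro_batch)
    # Any registered sends/recvs now live in `state`'s queues, to be drained by an engine.
    return output, state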
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel
logger = logging.get_logger(__name__)
class PipelineEngine(ABC):
def __init__(self):
        self.nb_microbatches: Optional[int] = None
def forward(
self,
context: ContextManagers,
state: PipelineTrainBatchState,
micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
model: torch_nn.Module,
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
        # Increment the number of forwards
state.nb_forwards += 1
log_rank(
f"Forward micro batch id: {state.nb_forwards}",
logger=logger,
level=logging.DEBUG,
)
        # IMPORTANT: this registers a new container in the state that will hold all the intermediary activations of this micro-batch
state.new_micro_batch_forward()
with context:
output = model(**micro_batch)
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# We normalize our loss
if not isinstance(output["loss"], TensorPointer):
output["loss"] = output["loss"] / self.nb_microbatches
# Add output as activations that require backward pass
if not isinstance(output["loss"], TensorPointer):
assert output["loss"].requires_grad
state.register_activation_requiring_backward(output["loss"])
return output
@staticmethod
def _get_fwd_context(model: torch_nn.Module):
is_ddp = isinstance(model, DistributedDataParallel)
        # We never want to trigger a DDP gradient sync here; the sync is triggered explicitly in the backward context of the last micro-batch
context = ContextManagers([model.no_sync()] if is_ddp else [])
return context
def backward(
self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator]
):
# Increment the number of backwards
state.nb_backwards += 1
log_rank(
f"Backward micro batch id: {state.nb_forwards}",
logger=logger,
level=logging.DEBUG,
)
# Go backward entirely
activations = state.pop_last_activations_requiring_backward()
if len(activations) == 0:
return
with context:
if grad_accumulator is None:
sum(activations).backward()
else:
grad_accumulator.backward(sum(activations))
# TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
# with context:
# if grad_accumulator is None:
# for activation in reversed(activations): #TODO @nouamane: need to bwd only 2nd chunk
# activation.backward()
# else:
# for activation in reversed(activations):
# grad_accumulator.backward(activation)
def _get_bwd_context(
self,
model: torch_nn.Module,
nb_backwards: int,
grad_accumulator: Optional[GradientAccumulator],
):
assert (
self.nb_microbatches is not None
), "You must call `train_batch_iter` first and set `self.nb_microbatches`"
is_ddp = isinstance(model, DistributedDataParallel)
context_list = []
if is_ddp:
if grad_accumulator is not None and nb_backwards < self.nb_microbatches - 1:
context_list.append(grad_accumulator.no_sync()) # Prevents grad accumulator from syncing
if nb_backwards == self.nb_microbatches - 1:
# Triggers DDP to sync gradients in the next backward pass
context_list.append(ddp_trigger_sync_in_bwd(model_ddp=model))
context = ContextManagers(context_list)
return context
@abstractmethod
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
"""If model returns tensor, we use it as a loss to backpropagate. If model returns a dict, we assume that the key "loss" is the loss to backpropagate."""
...
@torch.inference_mode()
def validate_batch_iter(
self,
model: torch_nn.Module,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState() # TODO: do i need state?
self.nb_microbatches = nb_microbatches
outputs = []
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
return outputs
class AllForwardAllBackwardPipelineEngine(PipelineEngine):
def __init__(self):
super().__init__()
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState()
self.nb_microbatches = nb_microbatches
outputs = []
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
# All backward
for _ in range(len(state.microbatches_activations_requiring_backward)):
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
# Make sure that micro batches are all fully consumed
state.check_buffers_empty()
return outputs
class OneForwardOneBackwardPipelineEngine(PipelineEngine):
def __init__(self):
super().__init__()
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
"""Check https://arxiv.org/abs/2104.04473 for diagrams for the pipeline engine"""
self.nb_microbatches = nb_microbatches
assert (
self.nb_microbatches >= pg.size() - 1
), f"Number of microbatches ({self.nb_microbatches}) must be at least PP_SIZE-1={pg.size() - 1} when using the OneForwardOneBackwardPipelineEngine"
state = PipelineTrainBatchState()
outputs = []
batch = iter(batch)
current_pp_rank = dist.get_rank(pg)
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# Init
for _ in range(pg.size() - current_pp_rank - 1):
micro_batch = next(batch)
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Send tensors
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
# One backward
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
            # Check the figure in the paper: the remaining blocks are all backward passes and there are only `pg.size() - current_pp_rank - 1` of them left
assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1
# No more activation to send/recv
assert (
len(state.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(state.microbatches_activations_to_send)}"
assert (
len(state.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(state.microbatches_activations_to_recv)}"
# Close: compute backward for the rest
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
for _ in range(len(state.microbatches_activations_requiring_backward)):
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
# Make sure that micro batches are all fully consumed
state.check_buffers_empty()
return outputs
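# A minimal training-step sketch (illustrative only). Assumptions: torch.distributed is
# initialized, `pp_pg` is the pipeline-parallel process group, `model` is built out of
# PipelineBlocks and returns the loss (or a dict with a "loss" key), `micro_batches` is
# a list of per-micro-batch input dicts, and `optimizer` is a hypothetical torch
# optimizer over the locally instantiated parameters.
def _example_train_step(model: torch_nn.Module, pp_pg: ProcessGroup, micro_batches: list, optimizer: torch.optim.Optimizer):
    engine = AllForwardAllBackwardPipelineEngine()
    outputs = engine.train_batch_iter(
        model=model,
        pg=pp_pg,
        batch=micro_batches,
        nb_microbatches=len(micro_batches),
        grad_accumulator=None,  # plain backward, no gradient accumulation object
    )
    # On the last pipeline stage `outputs` holds the detached per-micro-batch losses;
    # the other stages see TensorPointers instead.
    optimizer.step()
    optimizer.zero_grad()
    return outputs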
import torch
from nanotron import logging
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState
logger = logging.get_logger(__name__)
class SendTensorToPipelineBuffer(torch.autograd.Function):
"""Make sending tensors differentiable. The difference is here we don't use `torch.distributed` primites, but store events that's we will pop whenever we need"""
@staticmethod
def forward(
ctx,
activation: torch.Tensor,
to_rank: int,
p2p: P2P,
pipeline_state: PipelineBatchState,
):
assert activation.requires_grad
ctx.p2p = p2p
ctx.to_rank = to_rank
ctx.pipeline_state = pipeline_state
# Send tensors
pipeline_state.register_send_activation(activation, to_rank=to_rank, p2p=p2p)
        # HACK @thomasw21: Returning a dummy leaf that requires grad forces autograd to call backward on this node
return torch.tensor(1, dtype=torch.float, device="cpu", requires_grad=True)
@staticmethod
def backward(ctx, grad_tensor):
p2p = ctx.p2p
to_rank = ctx.to_rank
pipeline_state = ctx.pipeline_state
        # Register receiving the gradient from the peer and pop it from the buffer
pipeline_state.register_recv_grad(from_rank=to_rank, p2p=p2p)
if len(pipeline_state.grads_buffer) == 0:
pipeline_state.run_communication()
grad_tensor = pipeline_state.grads_buffer.popleft()
return grad_tensor, None, None, None
class SendTensorWithoutGradientToPipelineBuffer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
dummy_input: torch.Tensor,
activation: torch.Tensor,
to_rank: int,
p2p: P2P,
pipeline_state: PipelineBatchState,
):
assert dummy_input.requires_grad
assert activation.requires_grad is False
ctx.p2p = p2p
ctx.to_rank = to_rank
ctx.pipeline_state = pipeline_state
# Send tensors
pipeline_state.register_send_activation(activation, to_rank=to_rank, p2p=p2p)
        # HACK @thomasw21: Returning a dummy leaf that requires grad forces autograd to call backward on this node
return torch.tensor(1, dtype=torch.float, device="cpu", requires_grad=True)
@staticmethod
def backward(ctx, grad_tensor):
pipeline_state = ctx.pipeline_state
# send only the activations
pipeline_state.run_communication(send_only_activation=True)
return None, None, None, None, None
def send_to_pipeline_state_buffer(tensor: torch.Tensor, to_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
# This is used in order to know where to backward from.
if tensor.requires_grad:
result = SendTensorToPipelineBuffer.apply(tensor, to_rank, p2p, pipeline_state)
else:
        # Trick the backward mechanism into just sending the tensor.
dummy_input = torch.empty(1, dtype=torch.float, requires_grad=True, device="cpu")
result = SendTensorWithoutGradientToPipelineBuffer.apply(dummy_input, tensor, to_rank, p2p, pipeline_state)
pipeline_state.register_activation_requiring_backward(result)
class RecvTensorFromPipelineBuffer(torch.autograd.Function):
"""Make receiving tensors differentiable"""
@staticmethod
def forward(ctx, activation: torch.Tensor, from_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
ctx.pipeline_state = pipeline_state
ctx.p2p = p2p
ctx.from_rank = from_rank
return activation
@staticmethod
def backward(ctx, grad_tensor):
pipeline_state = ctx.pipeline_state
from_rank = ctx.from_rank
p2p = ctx.p2p
        # Register sending the gradient back to the rank we received the activation from
pipeline_state.register_send_grad(grad_tensor, to_rank=from_rank, p2p=p2p)
return None, None, None, None
def recv_from_pipeline_state_buffer(from_rank: int, p2p: P2P, pipeline_state: PipelineBatchState):
pipeline_state.register_recv_activation(from_rank=from_rank, p2p=p2p)
if len(pipeline_state.activations_buffer) == 0:
pipeline_state.run_communication()
activation = pipeline_state.activations_buffer.popleft()
return RecvTensorFromPipelineBuffer.apply(activation, from_rank, p2p, pipeline_state)
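# A minimal sketch of how the two helpers pair up across ranks (illustrative only).
# Assumptions: two pipeline ranks, `p2p` wraps their process group, both ranks share
# the same schedule over a PipelineTrainBatchState on which `new_micro_batch_forward()`
# has already been called (the engines do this), and `hidden` is the activation
# produced on rank 0.
from nanotron import distributed as dist


def _example_state_buffer_roundtrip(p2p: P2P, pipeline_state: PipelineBatchState, hidden: torch.Tensor):
    if dist.get_rank(p2p.pg) == 0:
        # Producer rank: queue the activation; the actual isend happens later, when
        # the pipeline state runs its communications.
        send_to_pipeline_state_buffer(hidden, to_rank=1, p2p=p2p, pipeline_state=pipeline_state)
        return None
    # Consumer rank: register the matching recv, drain the state's communication queue,
    # and get back a tensor hooked into autograd so that its backward sends the
    # gradient to rank 0.
    return recv_from_pipeline_state_buffer(from_rank=0, p2p=p2p, pipeline_state=pipeline_state)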
import dataclasses
from typing import List, Sequence, Tuple
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
logger = logging.get_logger(__name__)
FIRST_METADATA_SIZE = 7
SECOND_METADATA_SIZE = 1024
ID_TO_DTYPE = [
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
]
DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}
ID_TO_REQUIRES_GRAD = [True, False]
REQUIRES_GRAD_TO_ID = {value: id_ for id_, value in enumerate(ID_TO_REQUIRES_GRAD)}
ID_TO_IS_CONTIGUOUS = [True, False]
IS_CONTIGUOUS_TO_ID = {value: id_ for id_, value in enumerate(ID_TO_IS_CONTIGUOUS)}
@dataclasses.dataclass
class P2PTensorMetaData:
shape: Sequence[int]
stride: Sequence[int]
is_contiguous: bool
untyped_storage_size: int
storage_offset: int
dtype: torch.dtype
requires_grad: bool
def create_empty_storage(self, device: torch.device) -> torch.Tensor:
buffer = torch.empty(
size=(self.untyped_storage_size,),
requires_grad=False,
dtype=torch.int8,
device=device,
memory_format=torch.contiguous_format,
).view(dtype=self.dtype)
buffer.requires_grad = self.requires_grad
if self.is_contiguous:
buffer = buffer.as_strided(
size=tuple(self.shape), stride=tuple(self.stride), storage_offset=self.storage_offset
)
# Complex needs to be viewed as real first
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_real(buffer) if self.dtype.is_complex else buffer
return buffer
def reshape(self, buffer):
"""Changes the way we view buffer in order to fit metadata"""
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_complex(buffer) if self.dtype.is_complex else buffer
# Set shape and stride
if not self.is_contiguous:
buffer = buffer.as_strided(
size=tuple(self.shape), stride=tuple(self.stride), storage_offset=self.storage_offset
)
return buffer
@staticmethod
def to_first_metadata(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
# TODO @nouamane: avoid having two metadata comms, and preallocate shape/stride instead
return torch.tensor(
[
len(tensor.shape),
len(tensor.stride()),
IS_CONTIGUOUS_TO_ID[tensor.is_contiguous()],
get_untyped_storage(tensor).size(),
tensor.storage_offset(),
DTYPE_TO_ID[tensor.dtype],
REQUIRES_GRAD_TO_ID[tensor.requires_grad],
],
dtype=torch.long,
device=device,
)
@staticmethod
def to_second_metadata(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
return torch.tensor(tensor.shape + tensor.stride(), dtype=torch.long, device=device)
@classmethod
def from_metadata(cls, first_metadata: List[int], second_metadata: List[int]):
shape_and_stride = second_metadata
(
num_shape,
num_stride,
is_contiguous,
untyped_storage_size,
storage_offset,
dtype_id,
requires_grad_id,
) = first_metadata
return cls(
shape=shape_and_stride[: len(shape_and_stride) // 2],
stride=shape_and_stride[len(shape_and_stride) // 2 :],
is_contiguous=ID_TO_IS_CONTIGUOUS[is_contiguous],
untyped_storage_size=untyped_storage_size,
storage_offset=storage_offset,
dtype=ID_TO_DTYPE[dtype_id],
requires_grad=ID_TO_REQUIRES_GRAD[requires_grad_id],
)
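# A small round-trip of the two-phase metadata protocol (no communication involved;
# this only exercises the encode/decode helpers defined above, on CPU).
def _example_metadata_roundtrip() -> P2PTensorMetaData:
    device = torch.device("cpu")
    tensor = torch.randn(4, 6, dtype=torch.float32).t()  # non-contiguous on purpose
    first = P2PTensorMetaData.to_first_metadata(tensor, device=device).tolist()
    second = P2PTensorMetaData.to_second_metadata(tensor, device=device).tolist()
    meta = P2PTensorMetaData.from_metadata(first, second)
    assert tuple(meta.shape) == (6, 4) and meta.dtype == torch.float32 and not meta.is_contiguous
    # A receiving rank would allocate `meta.create_empty_storage(device)`, post the
    # irecv into it, and finally call `meta.reshape(...)` to restore the original view.
    return meta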
def view_as_contiguous(tensor: torch.Tensor):
"""Given a tensor, we want to view the tensor as a contiguous storage"""
tensor_numel = tensor.numel()
tensor_element_size = tensor.element_size()
untyped_storage = get_untyped_storage(tensor)
untyped_storage_size = untyped_storage.size()
untyped_element_size = untyped_storage.element_size()
    assert (
        tensor_numel * tensor_element_size >= untyped_storage_size * untyped_element_size
    ), "Expect the storage size to be no larger than the tensor size. This can be violated when you use slicing, for example; we probably don't want to support that in our P2P system"
buffer = tensor_from_untyped_storage(untyped_storage=untyped_storage, dtype=tensor.dtype)
return buffer
class P2P:
def __init__(self, pg: dist.ProcessGroup, device: torch.device):
self.pg = pg
self.device = device
self.first_metadata = torch.empty(FIRST_METADATA_SIZE, dtype=torch.long, device=self.device)
self.second_metadata = torch.empty(SECOND_METADATA_SIZE, dtype=torch.long, device=self.device)
def _send_first_metadata_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
first_metadata = P2PTensorMetaData.to_first_metadata(tensor=tensor, device=self.device)
return dist.P2POp(
op=dist.isend,
tensor=first_metadata,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_first_metadata_p2p_op(self, from_rank: int, tag: int = 0) -> Tuple[torch.Tensor, dist.P2POp]:
first_metadata_buffer = torch.empty((FIRST_METADATA_SIZE,), dtype=torch.long, device=self.device)
return first_metadata_buffer, dist.P2POp(
op=dist.irecv,
tensor=first_metadata_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_second_metadata_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
second_metadata = P2PTensorMetaData.to_second_metadata(tensor=tensor, device=self.device)
return dist.P2POp(
op=dist.isend,
tensor=second_metadata,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_second_metadata_p2p_op(
self, shape_length: int, stride_length: int, from_rank: int, tag: int = 0
) -> Tuple[torch.Tensor, dist.P2POp]:
second_metadata_buffer = torch.empty((shape_length + stride_length,), dtype=torch.long, device=self.device)
return second_metadata_buffer, dist.P2POp(
op=dist.irecv,
tensor=second_metadata_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_data_p2p_op(self, tensor: torch.Tensor, to_rank: int, tag: int = 0) -> dist.P2POp:
return dist.P2POp(
op=dist.isend,
tensor=tensor,
peer=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_data_p2p_op(
self, tensor_metadata: P2PTensorMetaData, from_rank: int, tag: int = 0
) -> Tuple[torch.Tensor, dist.P2POp]:
tensor_buffer = tensor_metadata.create_empty_storage(self.device)
return tensor_buffer, dist.P2POp(
op=dist.irecv,
tensor=tensor_buffer,
peer=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
def _send_meta(self, tensor: torch.Tensor, to_rank: int, tag: int):
cpu_tensor = torch.tensor(
[
len(tensor.shape),
len(tensor.stride()),
IS_CONTIGUOUS_TO_ID[tensor.is_contiguous()],
get_untyped_storage(tensor).size(),
tensor.storage_offset(),
DTYPE_TO_ID[tensor.dtype],
REQUIRES_GRAD_TO_ID[tensor.requires_grad],
],
dtype=torch.long,
)
self.first_metadata.copy_(cpu_tensor)
dist.send(
self.first_metadata,
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
second_metadata = tensor.shape + tensor.stride()
assert len(tensor.shape) == self.first_metadata[0]
assert len(tensor.stride()) == self.first_metadata[1]
# increase buffer size
if len(second_metadata) > len(self.second_metadata):
self.second_metadata = torch.empty(len(second_metadata), dtype=torch.long, device=self.device)
self.second_metadata[: len(second_metadata)].copy_(torch.tensor(second_metadata, dtype=torch.long))
dist.send(
self.second_metadata[: len(second_metadata)],
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
def _recv_meta(self, from_rank: int, tag: int) -> P2PTensorMetaData:
dist.recv(
self.first_metadata,
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
(
num_shape,
num_stride,
is_contiguous,
untyped_storage_size,
storage_offset,
dtype_id,
requires_grad_id,
) = self.first_metadata
# self.pg.recv([second], from_rank, 0).wait() # more direct API
second_metadata_num_elements = num_shape + num_stride
# increase buffer size
if second_metadata_num_elements > len(self.second_metadata):
self.second_metadata = torch.empty(second_metadata_num_elements, dtype=torch.long, device=self.device)
dist.recv(
self.second_metadata[:second_metadata_num_elements],
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
shape = self.second_metadata[:num_shape]
stride = self.second_metadata[num_shape:second_metadata_num_elements]
return P2PTensorMetaData(
dtype=ID_TO_DTYPE[dtype_id],
requires_grad=ID_TO_REQUIRES_GRAD[requires_grad_id],
shape=shape,
stride=stride,
is_contiguous=ID_TO_IS_CONTIGUOUS[is_contiguous],
untyped_storage_size=untyped_storage_size,
storage_offset=storage_offset,
)
def isend_tensors(self, tensors: List[torch.Tensor], to_rank: int, tag: int = 0) -> List[dist.Work]:
futures = []
current_rank = dist.get_rank(self.pg)
logger.debug(f"Current rank {current_rank} sending to rank {to_rank}. Nb_tensors: {len(tensors)}")
for tensor in tensors:
if to_rank != current_rank:
self._send_meta(tensor, to_rank=to_rank, tag=tag)
if tensor.is_contiguous():
buffer = tensor
else:
# If the tensor is not contiguous we send the entire storage
buffer = view_as_contiguous(tensor)
# TODO @thomasw21: Find the issue with send/recv complex tensors
buffer = torch.view_as_real(buffer) if buffer.is_complex() else buffer
futures.append(
dist.isend(
buffer,
dst=dist.get_global_rank(group=self.pg, group_rank=to_rank),
group=self.pg,
tag=tag,
)
)
else:
raise ValueError("Tried sending tensor to itself")
return futures
def irecv_tensors(
self, num_tensors: int, from_rank: int, tag: int = 0
) -> Tuple[List[torch.Tensor], List[dist.Work]]:
futures = []
buffers = []
current_rank = dist.get_rank(self.pg)
logger.debug(f"Current rank {current_rank} receiving from rank {from_rank}. Nb_tensors: {num_tensors}")
for _ in range(num_tensors):
if from_rank != current_rank:
meta = self._recv_meta(from_rank=from_rank, tag=tag)
buffer = meta.create_empty_storage(device=self.device)
futures.append(
dist.irecv(
buffer,
src=dist.get_global_rank(group=self.pg, group_rank=from_rank),
group=self.pg,
tag=tag,
)
)
buffer = meta.reshape(buffer=buffer)
# Add to the list
buffers.append(buffer)
else:
raise ValueError("Tried receiving tensor from itself")
return buffers, futures
def send_tensors(self, tensors: List[torch.Tensor], to_rank: int, tag: int = 0):
futures = self.isend_tensors(tensors=tensors, to_rank=to_rank, tag=tag)
for future in futures:
future.wait()
def recv_tensors(self, num_tensors: int, from_rank: int, tag: int = 0) -> List[torch.Tensor]:
buffers, futures = self.irecv_tensors(num_tensors=num_tensors, from_rank=from_rank, tag=tag)
for future in futures:
future.wait()
return buffers
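# A minimal two-rank sketch of the blocking helpers (illustrative only). Assumptions:
# torch.distributed is already initialized, `pg` is a 2-rank process group, and
# `device` is the local device the group communicates on.
def _example_p2p_exchange(pg: dist.ProcessGroup, device: torch.device):
    p2p = P2P(pg, device=device)
    if dist.get_rank(pg) == 0:
        payload = torch.arange(6, device=device).reshape(2, 3)
        p2p.send_tensors([payload], to_rank=1)
        return None
    # Rank 1 first receives the two metadata messages, allocates a matching buffer,
    # then receives the data; shape, dtype, strides and requires_grad are preserved.
    (received,) = p2p.recv_tensors(num_tensors=1, from_rank=0)
    return received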
class BatchTensorSendRecvState:
"""
This class is used to register send/recv batches of tensors, and
then executes send/recv in `flush()` calls. This is useful for
amortizing the cost of sending and receiving tensors over multiple
iterations.
"""
p2p: P2P
first_metadata_p2p_ops: List[dist.P2POp]
second_metadata_p2p_ops: List[dist.P2POp]
data_p2p_ops: List[dist.P2POp]
recv_first_metadata_buffers: List[torch.Tensor]
recv_from_ranks: List[int]
def __init__(self, p2p: P2P):
self.p2p = p2p
self._reset()
def _reset(self):
self.first_metadata_p2p_ops: List[dist.P2POp] = []
self.second_metadata_p2p_ops: List[dist.P2POp] = []
self.data_p2p_ops: List[dist.P2POp] = []
self.recv_first_metadata_buffers: List[torch.Tensor] = []
self.recv_from_ranks: List[int] = []
def __str__(self):
return f"BatchTensorSendRecvState(first_metadata_p2p_ops={len(self.first_metadata_p2p_ops)}, second_metadata_p2p_ops={len(self.second_metadata_p2p_ops)}, data_p2p_ops={len(self.data_p2p_ops)}, recv_first_metadata_buffers={len(self.recv_first_metadata_buffers)}, recv_from_ranks={self.recv_from_ranks})"
def add_send(self, tensor: torch.Tensor, to_rank: int, tag: int = 0):
self.first_metadata_p2p_ops.append(
self.p2p._send_first_metadata_p2p_op(tensor=tensor, to_rank=to_rank, tag=tag)
)
self.second_metadata_p2p_ops.append(
self.p2p._send_second_metadata_p2p_op(tensor=tensor, to_rank=to_rank, tag=tag)
)
self.data_p2p_ops.append(
self.p2p._send_data_p2p_op(tensor=view_as_contiguous(tensor), to_rank=to_rank, tag=tag)
)
def add_recv(self, from_rank: int, tag: int = 0) -> int:
"""
        Only add p2p ops for the first metadata, as `_recv_second_metadata_p2p_op` and `_recv_data_p2p_op`
        require the results of the first metadata to be transferred first.
Return: index of the recv_buffer in `self.recv_first_metadata_buffers`
"""
buffer, recv_op = self.p2p._recv_first_metadata_p2p_op(from_rank=from_rank, tag=tag)
self.first_metadata_p2p_ops.append(recv_op)
self.recv_first_metadata_buffers.append(buffer)
self.recv_from_ranks.append(from_rank)
return len(self.recv_first_metadata_buffers) - 1
def _send_recv_first_metadata(self) -> List[List[int]]:
# Send/Recv first metadata
reqs = dist.batch_isend_irecv(self.first_metadata_p2p_ops)
for req in reqs:
req.wait()
# We want an early cpu/gpu sync here as we are right after the wait so it's nearly free.
# Removing the tolist call here delays the sync and will impact performance.
# We need to instantiate it in a list because it is used twice
first_metadatas = [tensor.tolist() for tensor in self.recv_first_metadata_buffers]
return first_metadatas
def _send_recv_second_metadata(self, first_metadata: List[List[int]]) -> List[List[int]]:
# turn a list of tuple into a tuple of list
recv_second_metadata_buffers, recv_second_metadata_ops = zip(
*(
self.p2p._recv_second_metadata_p2p_op(
shape_length=num_shape, stride_length=num_stride, from_rank=from_rank
)
for (num_shape, num_stride, *_), from_rank in zip(first_metadata, self.recv_from_ranks)
)
)
recv_second_metadata_ops = list(recv_second_metadata_ops)
# Send/Recv second metadata
reqs = dist.batch_isend_irecv(self.second_metadata_p2p_ops + recv_second_metadata_ops)
for req in reqs:
req.wait()
# We want an early cpu/gpu sync here as we are right after the wait so it's nearly free.
# Removing the tolist call here delays the sync and will impact performance.
second_metadatas = [tensor.tolist() for tensor in recv_second_metadata_buffers]
return second_metadatas
def _send_recv_data(self, tensor_metadatas: List[P2PTensorMetaData]) -> List[torch.Tensor]:
# turn a list of tuples into a tuple of list
recv_data_buffers, recv_data_ops = zip(
*(
self.p2p._recv_data_p2p_op(tensor_metadata=tensor_metadata, from_rank=from_rank)
for tensor_metadata, from_rank in zip(tensor_metadatas, self.recv_from_ranks)
)
)
recv_data_ops = list(recv_data_ops)
# Send/Recv tensor data
futures = dist.batch_isend_irecv(self.data_p2p_ops + recv_data_ops)
for future in futures:
future.wait()
# Format tensor by setting the stride
return [
recv_data_buffer.as_strided(size=tuple(tensor_metadata.shape), stride=tuple(tensor_metadata.stride))
for recv_data_buffer, tensor_metadata in zip(recv_data_buffers, tensor_metadatas)
]
def flush(self) -> List[torch.Tensor]:
"""
Run all communication in a batch.
Return `torch.Tensor` in the case of recv.
"""
assert len(self.recv_first_metadata_buffers) == len(
self.recv_from_ranks
), f"len(self.recv_first_metadata_buffers)={len(self.recv_first_metadata_buffers)}, len(self.recv_from_ranks)={len(self.recv_from_ranks)} but should be equal."
# If there is no communication, return
if len(self.first_metadata_p2p_ops) == 0:
return []
# If there is no recv
if len(self.recv_first_metadata_buffers) == 0:
reqs = dist.batch_isend_irecv(
self.first_metadata_p2p_ops + self.second_metadata_p2p_ops + self.data_p2p_ops
)
for req in reqs:
req.wait()
self._reset()
return []
# Send/Recv first metadata
logger.debug(f"First metadata: {[p2pop.op for p2pop in self.first_metadata_p2p_ops]}")
# TODO(kunhao): We could actually send all at once like the above no recv case. But I need to benchmark the performance.
first_metadatas = self._send_recv_first_metadata()
# Send/Recv second metadata
second_metadatas = self._send_recv_second_metadata(first_metadatas)
tensor_metadatas = [
P2PTensorMetaData.from_metadata(first_metadata, second_metadata)
for first_metadata, second_metadata in zip(first_metadatas, second_metadatas)
]
recv_tensors = self._send_recv_data(tensor_metadatas)
# Reset state
self._reset()
return recv_tensors
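# A minimal sketch of batching several transfers into one flush (illustrative only).
# Assumptions: `p2p` wraps a pipeline-parallel group with at least 3 ranks, this code
# runs on rank 1, rank 0 runs a matching `add_send` towards rank 1 and rank 2 runs a
# matching `add_recv` from rank 1, so the batched ops pair up.
def _example_batched_send_recv(p2p: P2P, hidden: torch.Tensor) -> torch.Tensor:
    batch_comm = BatchTensorSendRecvState(p2p)
    batch_comm.add_send(tensor=hidden, to_rank=2)  # queued only, nothing is sent yet
    recv_id = batch_comm.add_recv(from_rank=0)  # index into the list returned by flush()
    recv_tensors = batch_comm.flush()  # metadata then data, all batched
    return recv_tensors[recv_id]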
import collections
import dataclasses
from abc import ABC, abstractmethod
from typing import List
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.p2p import P2P
logger = logging.get_logger(__name__)
@dataclasses.dataclass
class SendActivation:
activation: torch.Tensor
to_rank: int
p2p: P2P
def __call__(self):
self.p2p.send_tensors([self.activation], to_rank=self.to_rank)
@dataclasses.dataclass
class RecvActivation:
from_rank: int
p2p: P2P
def __call__(self) -> torch.Tensor:
return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]
@dataclasses.dataclass
class SendGrad:
grad: torch.Tensor
to_rank: int
p2p: P2P
def __call__(self):
self.p2p.send_tensors([self.grad], to_rank=self.to_rank)
@dataclasses.dataclass
class RecvGrad:
from_rank: int
p2p: P2P
def __call__(self) -> torch.Tensor:
return self.p2p.recv_tensors(num_tensors=1, from_rank=self.from_rank)[0]
class PipelineBatchState(ABC):
activations_buffer = collections.deque()
@abstractmethod
def register_activation_requiring_backward(self, activation: torch.Tensor):
...
@abstractmethod
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
...
@abstractmethod
def register_recv_activation(self, from_rank: int, p2p: P2P):
...
@abstractmethod
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
...
@abstractmethod
def register_recv_grad(self, from_rank: int, p2p: P2P):
...
@abstractmethod
def run_communication(self, send_only_activation: bool = False):
...
@abstractmethod
def new_micro_batch_forward(self):
...
@abstractmethod
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
...
@dataclasses.dataclass
class PipelineTrainBatchState(PipelineBatchState):
microbatches_activations_to_send = collections.deque()
microbatches_activations_to_recv = collections.deque()
microbatches_grads_to_send = collections.deque()
microbatches_grads_to_recv = collections.deque()
grads_buffer = collections.deque()
    # List of lists: the first index is the micro-batch id, the second indexes the activations that need to be popped
    microbatches_activations_requiring_backward = collections.deque()
    # Reinitialise counters
nb_backwards = 0
nb_forwards = 0
def register_activation_requiring_backward(self, activation: torch.Tensor):
# Register the activation to last microbatch
self.microbatches_activations_requiring_backward[-1].append(activation)
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
        # TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send activations to higher ranks
self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))
def register_recv_activation(self, from_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv activations from lower ranks
self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only send gradients to lower ranks
self.microbatches_grads_to_send.append(SendGrad(grad=grad, to_rank=to_rank, p2p=p2p))
def register_recv_grad(self, from_rank: int, p2p: P2P):
# TODO @thomasw21: We assume that each rank has a single contiguous list of blocks. This also means that we only recv gradients from higher ranks
self.microbatches_grads_to_recv.append(RecvGrad(from_rank=from_rank, p2p=p2p))
def run_communication(self, send_only_activation: bool = False):
"""Run communication in a specific order: send activation, recv activation, send grad, recv grad
Only one communication is done at a time."""
log_rank(
f"activation_to_send: {len(self.microbatches_activations_to_send)} | "
f"activation_to_recv: {len(self.microbatches_activations_to_recv)} | "
f"grads_to_send: {len(self.microbatches_grads_to_send)} | "
f"grads_to_recv: {len(self.microbatches_grads_to_recv)} | "
f"activation_buffer: {len(self.activations_buffer)} | "
f"grads_buffer: {len(self.grads_buffer)}",
logger=logger,
level=logging.DEBUG,
)
# Pop one send activation
if len(self.microbatches_activations_to_send) > 0:
send_activation = self.microbatches_activations_to_send.popleft()
# Execute
activation_send_requires_grad = send_activation.activation.requires_grad
send_activation()
if send_only_activation:
return
# Pop one recv activation
if len(self.microbatches_activations_to_recv) > 0:
recv_activation = self.microbatches_activations_to_recv.popleft()
# Execute
recv_activation_tensor = recv_activation()
self.activations_buffer.append(recv_activation_tensor)
            # If we receive a tensor that doesn't require a backward pass, we shouldn't do the cross communication
if recv_activation_tensor.requires_grad is False:
return
# Pop one send gradient
if len(self.microbatches_grads_to_send) > 0:
send_grad = self.microbatches_grads_to_send.popleft()
# Execute
send_grad()
# Pop one recv gradient
if len(self.microbatches_grads_to_recv) > 0:
            # Before receiving a gradient, flush pending activation sends until one that requires grad has gone out
while len(self.microbatches_activations_to_send) > 0 and not activation_send_requires_grad:
send_activation = self.microbatches_activations_to_send.popleft()
# Execute
activation_send_requires_grad = send_activation.activation.requires_grad
send_activation()
recv_grad = self.microbatches_grads_to_recv.popleft()
# Execute
self.grads_buffer.append(recv_grad())
        # TODO @thomasw21: I need some mechanism to point to whatever is now stored in a buffer, typically some id that would point to the correct tensor in our buffer instead of relying on the sorted list.
def new_micro_batch_forward(self):
self.microbatches_activations_requiring_backward.append(collections.deque())
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
return self.microbatches_activations_requiring_backward.popleft()
def check_buffers_empty(self):
assert (
len(self.microbatches_activations_requiring_backward) == 0
), f"There are still activations that require backward: {len(self.microbatches_activations_requiring_backward)}"
assert (
len(self.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
assert (
len(self.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
assert (
len(self.microbatches_grads_to_send) == 0
), f"There are gradients left for me to send still: {len(self.microbatches_grads_to_send)}"
assert (
len(self.microbatches_grads_to_recv) == 0
), f"There are gradients left for me to recv still: {len(self.microbatches_grads_to_recv)}"
@dataclasses.dataclass
class PipelineEvalBatchState(PipelineBatchState):
microbatches_activations_to_send = collections.deque()
microbatches_activations_to_recv = collections.deque()
activations_buffer = collections.deque()
def register_activation_requiring_backward(self, activation: torch.Tensor):
pass
def register_send_activation(self, activation: torch.Tensor, to_rank: int, p2p: P2P):
self.microbatches_activations_to_send.append(SendActivation(activation=activation, to_rank=to_rank, p2p=p2p))
        # If there's a cross communication (both a pending send and a pending recv), run it now
        if len(self.microbatches_activations_to_send) > 0 and len(self.microbatches_activations_to_recv) > 0:
            self.run_communication()
def register_recv_activation(self, from_rank: int, p2p: P2P):
self.microbatches_activations_to_recv.append(RecvActivation(from_rank=from_rank, p2p=p2p))
        # If there's a cross communication (both a pending send and a pending recv), run it now
        if len(self.microbatches_activations_to_send) > 0 and len(self.microbatches_activations_to_recv) > 0:
            self.run_communication()
def register_send_grad(self, grad: torch.Tensor, to_rank: int, p2p: P2P):
raise NotImplementedError("You can't register a send grad in pipeline eval mode")
def register_recv_grad(self, from_rank: int, p2p: P2P):
raise NotImplementedError("You can't register a recv grad in pipeline eval mode")
def new_micro_batch_forward(self):
pass
def pop_last_activations_requiring_backward(self) -> List[torch.Tensor]:
pass
def run_communication(self, send_only_activation: bool = False):
# four cases:
# - you receive from higher rank and you send to higher rank
# - You receive from higher rank and you send to lower rank
# - you receive from lower rank and you send to higher rank
# - you receive from lower rank and you send to lower rank
send_activation = None
        # Pop at most one send activation
for _ in range(min(1, len(self.microbatches_activations_to_send))):
send_activation = self.microbatches_activations_to_send.popleft()
        # Pop at most one recv activation
recv_activation = None
for _ in range(min(1, len(self.microbatches_activations_to_recv))):
recv_activation = self.microbatches_activations_to_recv.popleft()
if send_activation is None:
if recv_activation is None:
raise ValueError("Why the hell do we communicate when there's nothing to communicate?")
self.activations_buffer.append(recv_activation())
else:
if recv_activation is None:
send_activation()
else:
                # Define in which order we do it.
                # Actually we can't rely on a heuristic, as you need global information to define a clear ordering.
                # We make a BIG assumption: only ONE rank both receives from a higher rank and sends to a higher rank.
                # That "lowest" rank sends first and then receives; every other rank receives first and sends after.
                # If we knew who was involved in the cycle, we could just pick one rank to send first and then recv; however it's not clear who's involved.
p2p = send_activation.p2p
assert p2p == recv_activation.p2p
is_lowest = send_activation.to_rank > dist.get_rank(
p2p.pg
) and recv_activation.from_rank > dist.get_rank(p2p.pg)
if is_lowest:
send_activation()
self.activations_buffer.append(recv_activation())
else:
self.activations_buffer.append(recv_activation())
send_activation()
def check_buffers_empty(self):
assert (
len(self.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(self.microbatches_activations_to_send)}"
assert (
len(self.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(self.microbatches_activations_to_recv)}"
assert (
len(self.activations_buffer) == 0
), f"There are activations left in the buffer: {len(self.activations_buffer)}"
import dataclasses
@dataclasses.dataclass
class TensorPointer:
"""Dataclass specifying from which rank we need to query a tensor from in order to access data"""
# Needed to understand from which rank to get the tensor
# TODO @thomasw21: Maybe add which group it belongs to as well? Typically this is highly correlated to `p2p.pg`
group_rank: int
# TODO @thomasw21: Maybe add a tag (torch.distributed.send/recv allow for tagging)
from nanotron.models import NanotronModel
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from torch import nn
from torch.nn.parallel import DistributedDataParallel
def get_input_output_pp_ranks(model: NanotronModel | DistributedDataParallel):
if isinstance(model, DistributedDataParallel):
input_pp_rank = model.module.input_pp_rank
output_pp_rank = model.module.output_pp_rank
else:
input_pp_rank = model.input_pp_rank
output_pp_rank = model.output_pp_rank
return input_pp_rank, output_pp_rank
def get_pp_rank_of(target: str, module: nn.Module):
"""Assuming a model with pipeline blocks, we want to know in which pp rank the module/parameter whose name is `target`"""
if isinstance(module, PipelineBlock):
return module.rank
atoms = target.split(".")
current_module = module
for atom in atoms:
if not hasattr(current_module, atom):
raise AttributeError(f'{current_module._get_name()} has no attribute `"{atom}"`')
current_module = getattr(current_module, atom)
if isinstance(current_module, PipelineBlock):
return current_module.rank
if not isinstance(current_module, nn.Module):
raise AttributeError(f'`"{atom}"` is not an nn.Module')
raise ValueError(f'`"{target}" is not inside a PipelineBlock and thus does not have a pp_rank')
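# A small sketch of the lookup helpers (illustrative only). The dotted path "decoder.3"
# is purely hypothetical; any name that resolves to a module or parameter living inside
# a PipelineBlock works, and `model` is assumed to expose `input_pp_rank`/`output_pp_rank`
# as NanotronModel does.
def _example_pp_rank_lookup(model: nn.Module):
    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model)
    block_pp_rank = get_pp_rank_of("decoder.3", module=model)
    return input_pp_rank, output_pp_rank, block_pp_rank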
import dataclasses
from typing import List, Optional, Tuple
import numpy as np
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel.parameters import NanotronParameter, SlicesPair
@dataclasses.dataclass
class SplitConfig:
split_dim: int
# contiguous_chunks is a tuple of chunk sizes along the split_dim
# sharding happens inside each chunk
    # if None, by default contiguous_chunks = (unsharded_param.shape[split_dim],)
contiguous_chunks: Optional[Tuple[int, ...]] = None
def create_sharded_parameter(
parameter: nn.Parameter,
global_ranks: Tuple[int, ...],
local_global_slices_pairs: Tuple[SlicesPair, ...],
unsharded_shape: Tuple[int, ...],
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_sharded(
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
return parameter
def create_sharded_parameter_from_config(
parameter: nn.Parameter,
pg: dist.ProcessGroup,
split_config: SplitConfig,
) -> NanotronParameter:
current_rank = dist.get_rank(pg)
param_num_dims = len(parameter.shape)
global_ranks = dist.get_global_ranks(pg)
split_dim = split_config.split_dim
assert split_dim < param_num_dims
contiguous_chunks = split_config.contiguous_chunks
if contiguous_chunks is None:
# we are assuming that the parameter is contiguous along the split_dim, i.e. 1 whole chunk
# all parameters are equally shardable across the process group along the split_dim
shard_length = parameter.shape[split_dim]
global_slice = slice(current_rank * shard_length, (current_rank + 1) * shard_length)
# construct a mapping from local slices to global slices, multi-dimensional version
local_slices = tuple(slice(None) for _ in range(param_num_dims))
global_slices = tuple(global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims))
local_global_slices_pairs = (SlicesPair(local_slices=local_slices, global_slices=global_slices),)
unsharded_shape = tuple(
pg.size() * param_dim_size if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
else:
        # support custom contiguous chunk sizes for sharding along the split_dim
local_global_slices_pairs: List[SlicesPair] = []
chunks_global_offset = np.cumsum((0,) + contiguous_chunks)
chunks_local_offset = chunks_global_offset // pg.size()
for chunk, chunk_global_start, chunk_local_start, chunk_local_end in zip(
contiguous_chunks,
chunks_global_offset[:-1],
chunks_local_offset[:-1],
chunks_local_offset[1:],
strict=True,
):
# we assume that we are doing equal split at the chunk level
assert chunk % pg.size() == 0, f"chunk size {chunk} must be divisible by process group size {pg.size()}"
shard_length = chunk // pg.size()
# we have: chunk_local_end = chunk_local_start + shard_length
local_slice = slice(chunk_local_start, chunk_local_end)
global_slice = slice(
current_rank * shard_length + chunk_global_start,
(current_rank + 1) * shard_length + chunk_global_start,
)
local_slices = tuple(
local_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
global_slices = tuple(
global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
local_global_slices_pairs.append(SlicesPair(local_slices=local_slices, global_slices=global_slices))
local_global_slices_pairs: Tuple[SlicesPair, ...] = tuple(local_global_slices_pairs)
unsharded_shape = tuple(
chunks_global_offset[-1] if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
return create_sharded_parameter(
parameter=parameter,
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
def mark_all_parameters_in_module_as_sharded(module: nn.Module, pg: dist.ProcessGroup, split_config: SplitConfig):
"""
Mark parameters as sharded within a module. We assume that parameters are equally shardable across the process group.
:param module: nn.Module
:param pg: dist.ProcessGroup
:param split_config: SplitConfig
:return:
"""
for module_name, submodule in module.named_modules():
for param_name, param in list(submodule.named_parameters(recurse=False)):
new_param = create_sharded_parameter_from_config(parameter=param, pg=pg, split_config=split_config)
setattr(submodule, param_name, new_param)
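# A minimal tensor-parallel sharding sketch (illustrative only). Assumptions:
# torch.distributed is initialized and `tp_pg` is the tensor-parallel process group;
# the layer sizes are arbitrary. With split_dim=0 and no contiguous_chunks, each rank
# keeps its local rows, and the recorded unsharded weight shape becomes
# (tp_pg.size() * local_out_features, in_features).
def _example_shard_linear(tp_pg: dist.ProcessGroup) -> nn.Linear:
    local_out_features = 16  # per-rank shard; the full layer has 16 * tp_pg.size() rows
    layer = nn.Linear(in_features=32, out_features=local_out_features)
    mark_all_parameters_in_module_as_sharded(
        module=layer,
        pg=tp_pg,
        split_config=SplitConfig(split_dim=0),
    )
    # layer.weight / layer.bias are now NanotronParameters carrying the local->global
    # slice mapping, e.g. used when gathering full tensors for checkpoints.
    return layer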