"""Most granular pipeline block, ie within this module, everything will be part of a single rank, ie the entire computation within this block will happen on a specific device.
Current limitations:
- PipelineBlocks have to wrap a method/function/module that outputs a Dict[str, torch.Tensor]
Some considerations:
- In the literature, authors often refer to pipeline stages as a granularity block. Our notion is more granular. A pipeline stage is list of contiguous (in the forward sense) of pipeline blocks.
All PipelineBlock definition exist in each rank, they are just instantiated/built on a single rank per pipeline parallel process group.
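# Hedged usage sketch (illustration only): the one hard requirement stated
# above is that the wrapped callable returns a Dict[str, torch.Tensor]. The
# PipelineBlock constructor arguments shown below (module_builder,
# module_kwargs, module_input_keys, module_output_keys) are assumptions made
# for this example, not a confirmed signature.
import torch
from torch import nn


class TinyMLP(nn.Module):
    """A module whose forward returns a Dict[str, torch.Tensor], as required above."""

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> dict:
        # Must return a Dict[str, torch.Tensor]
        return {"hidden_states": self.proj(hidden_states)}


# Hypothetical construction; every rank sees this definition, but the TinyMLP
# weights are only materialized on one rank of the pipeline-parallel group:
#
#     block = PipelineBlock(
#         p2p=p2p,
#         module_builder=TinyMLP,
#         module_kwargs={"hidden_size": 1024},
#         module_input_keys={"hidden_states"},
#         module_output_keys={"hidden_states"},
#     )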
# Send activations from other devices to local rank
for name, tensor in sorted_kwargs:
    if isinstance(tensor, TensorPointer):
        # Current rank is neither the rank holding the data nor the rank responsible for computing block
        continue
    else:
        assert isinstance(tensor, torch.Tensor)
        # We need to send the tensor to the rank that actually runs the compute
        if self.pipeline_state is not None:
            send_to_pipeline_state_buffer(
                tensor,
                to_rank=self.rank,
                p2p=self.p2p,
                pipeline_state=self.pipeline_state,
            )
            continue

        if tensor.requires_grad is True:
            raise ValueError(
                f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
            )
# This assumes that prior communication was already done.
# In case of interleaved 1f1b, if this is the second model chunk, then we need to send the previous activations before receiving the current activations.
f"Pipeline engine is None and tensor requires grad. Tried receiving a tensor from {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
"""If model returns tensor, we use it as a loss to backpropagate. If model returns a dict, we assume that the key "loss" is the loss to backpropagate."""
"""Make sending tensors differentiable. The difference is here we don't use `torch.distributed` primites, but store events that's we will pop whenever we need"""
),"Expect storage_size to be smaller than tensor size. It might not be true, when you use slicing for example though. We probably don't want to support it in our P2P system"
),f"len(self.recv_first_metadata_buffers)={len(self.recv_first_metadata_buffers)}, len(self.recv_from_ranks)={len(self.recv_from_ranks)} but should be equal."
# TODO @thomasw21: I need some mechanism to point to whatever is now stored in a buffer, typically some id that would point to the correct tensor in our buffer instead of relying on the sorted list.
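# Hedged sketch of what the TODO above suggests (purely illustrative; none of
# these names exist in the codebase): hand out an explicit id when a tensor is
# stored, and look it up by that id instead of relying on the order of a
# sorted list.
import itertools
from typing import Dict

import torch


class IdKeyedBuffer:
    def __init__(self):
        self._next_id = itertools.count()
        self._tensors: Dict[int, torch.Tensor] = {}

    def put(self, tensor: torch.Tensor) -> int:
        tensor_id = next(self._next_id)
        self._tensors[tensor_id] = tensor
        return tensor_id  # the id can travel with the metadata instead of a list position

    def pop(self, tensor_id: int) -> torch.Tensor:
        return self._tensors.pop(tensor_id)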