# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Communications utilities."""


import torch

from megatron.core import parallel_state
from megatron.core import mpu


# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()


# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()


def _is_cuda(tensor):
    """Check that a tensor is not None and lives on a CUDA device."""
    assert tensor is not None
    assert tensor.is_cuda


def _is_cuda_contiguous(tensor):
    """Check that a tensor is not None, lives on a CUDA device, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()


def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor
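
# Illustrative usage sketch (not part of this module's API): every rank in the
# pipeline group obtains the tokens sampled on the last stage with a single
# call. The names `batch_size` and `new_tokens` are hypothetical; the caller is
# assumed to have initialized torch.distributed and Megatron's parallel state.
#
#     tokens = broadcast_from_last_pipeline_stage(
#         [batch_size], torch.int64,
#         tensor=new_tokens if mpu.is_pipeline_last_stage() else None)
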
def _send_and_recv_from_last_to_first_pipeline_stage(tensor=None):
    """Point-to-point transfer of a tensor from the last pipeline stage
    to the first pipeline stage. Other stages are not involved."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    if is_last_stage or is_first_stage:
        if is_first_stage:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor,
                mpu.get_pipeline_model_parallel_last_rank())
            reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        elif is_last_stage:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor,
                mpu.get_pipeline_model_parallel_first_rank())
            reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    return tensor


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from the last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        tensor = _send_and_recv_from_last_to_first_pipeline_stage(tensor)
    else:
        tensor = None

    return tensor


def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        tensor_ = _send_and_recv_from_last_to_first_pipeline_stage(tensor_)
        # Update the first stage tensor.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_


def broadcast_tensor(size, dtype, tensor=None, rank=0, data_parallel=False):
    """Given the size and dtype of a tensor on all ranks and the tensor value
    only on a specific rank, broadcast from that rank to all other ranks.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """
    if data_parallel:
        rank = parallel_state.get_model_parallel_src_rank()

    if torch.distributed.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    group = None
    if data_parallel:
        group = parallel_state.get_model_parallel_group()

    torch.distributed.broadcast(tensor, rank, group=group)

    return tensor


def broadcast_list(size, dtype, list_values=None, rank=0, data_parallel=False):
    """Broadcast a list of values with a given type.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """
    tensor = None

    if data_parallel:
        if parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank():
            tensor = torch.tensor(list_values, dtype=dtype,
                                  device=torch.cuda.current_device())
        rank = parallel_state.get_model_parallel_src_rank()
    else:
        if torch.distributed.get_rank() == rank:
            tensor = torch.tensor(list_values, dtype=dtype,
                                  device=torch.cuda.current_device())

    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank,
                            data_parallel=data_parallel)


def broadcast_int_list(size, int_list=None, rank=0, data_parallel=False):
    """Broadcast a list of integer values.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """
    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank,
                          data_parallel=data_parallel)


def broadcast_float_list(size, float_list=None, rank=0, data_parallel=False):
    """Broadcast a list of float values.

    Args:
        data_parallel (bool): Broadcast across a single data parallel model replica.
    """
    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank, data_parallel=data_parallel)
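
# Illustrative usage sketch (not part of this module's API): rank 0 shares
# sampling parameters with every rank, and the first stage receives the tokens
# the last stage generated; ranks that are neither first nor last stage get
# None back. The names `temperature`, `top_p`, `batch_size`, and `new_tokens`
# are hypothetical; torch.distributed and Megatron's parallel state are assumed
# to be initialized.
#
#     # Only the source rank needs real values; other ranks may pass None.
#     values = broadcast_float_list(
#         2,
#         float_list=[temperature, top_p] if torch.distributed.get_rank() == 0 else None,
#         rank=0)
#     temperature, top_p = values[0].item(), values[1].item()
#
#     new_tokens = broadcast_from_last_to_first_pipeline_stage(
#         [batch_size], torch.int64,
#         tensor=new_tokens if mpu.is_pipeline_last_stage() else None)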