import torch


def p2p_communicate(
    rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm
):
    """Point-to-point communications of KV and dKV in Attention with context parallelism"""
    send_recv_ops = []

    if batch_p2p_comm:
        # In the caller, batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) -- why?
        # Batched path: build P2POp descriptors and launch them together with
        # batch_isend_irecv. Even ranks post send before recv and odd ranks the
        # reverse, so ring neighbors issue their operations in complementary
        # order and do not deadlock each other when P2P ops are serialized.
        if rank % 2 == 0:
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.P2POp(
                torch.distributed.irecv, recv_tensor, recv_src, cp_group
            )
            send_op = torch.distributed.P2POp(
                torch.distributed.isend, send_tensor, send_dst, cp_group
            )
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
    else:
        # Unbatched path: issue the non-blocking isend/irecv calls directly,
        # with the same even/odd ordering, and return the request handles for
        # the caller to wait on.
        if rank % 2 == 0:
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = send_recv_ops

    return send_recv_reqs
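
Below is a minimal usage sketch (my own example, not part of TransformerEngine): each rank sends its KV chunk to the next rank in the ring and receives the previous rank's chunk, which is how the context-parallel attention loop rotates KV between steps. The script name, process-group setup, gloo backend, and tensor shapes are assumptions made for illustration; run it with, e.g., torchrun --nproc_per_node=2 ring_demo.py.

# ring_demo.py -- hypothetical demo, assumes p2p_communicate() from above is in scope
import torch
import torch.distributed as dist


def main():
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR; gloo keeps the demo CPU-only
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    cp_group = dist.group.WORLD  # stand-in for the real context-parallel group

    send_dst = (rank + 1) % world_size  # next rank in the ring
    recv_src = (rank - 1) % world_size  # previous rank in the ring

    send_tensor = torch.full((4,), float(rank))  # pretend KV chunk owned by this rank
    recv_tensor = torch.empty(4)                 # buffer for the neighbor's chunk

    reqs = p2p_communicate(
        rank, send_tensor, send_dst, recv_tensor, recv_src,
        cp_group, batch_p2p_comm=False,
    )
    for req in reqs:
        req.wait()  # block until the send/recv pair completes

    print(f"rank {rank} received chunk from rank {recv_src}: {recv_tensor.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()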