tensor_pointer.py 452 Bytes
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
import dataclasses


@dataclasses.dataclass
class TensorPointer:
    """Dataclass specifying from which rank we need to query a tensor from in order to access data"""

    # Needed to understand from which rank to get the tensor
    # TODO @thomasw21: Maybe add which group it belongs to as well? Typically this is highly correlated to `p2p.pg`
    group_rank: int
    # TODO @thomasw21: Maybe add a tag (torch.distributed.send/recv allow for tagging)