helpers.py 679 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


def dist_init(rank, world_size, init_method="tcp://localhost:29501"):
    """Initialise the default ``torch.distributed`` process group for this process.

    NCCL is selected when CUDA is available, Gloo otherwise.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        init_method: Rendezvous URL passed to ``dist.init_process_group``.
            Defaults to the previously hard-coded ``tcp://localhost:29501`` so
            existing callers are unaffected; pass another port (or a
            ``file://`` URL) to avoid clashing with another job on the same host.
    """
    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
    print(f"Using backend: {backend}")
    dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)


13
def get_model():
    """Build a small MLP: 10 input features -> 10 hidden units (ReLU) -> 5 outputs.

    The final layer is a plain ``Linear``, so the outputs are unnormalised
    scores (no softmax / log-softmax is applied).
    """
    return nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))


17
def get_data(n_batches=1, batch_size=20):
    """Generate random ``(inputs, targets)`` batches for smoke tests.

    Args:
        n_batches: Number of batches to produce.
        batch_size: Number of samples per batch (defaults to 20, the value
            previously hard-coded).

    Returns:
        A list of ``n_batches`` tuples ``(x, y)``. ``x`` is a float tensor of
        shape ``(batch_size, 10)`` drawn from a standard normal distribution,
        and ``y`` is an integer tensor of shape ``(batch_size,)`` with labels
        in ``{0, 1}``.
    """
    # squeeze(1) instead of a bare squeeze(): with batch_size == 1 a bare
    # squeeze() would collapse the targets to a 0-d tensor.
    return [
        (torch.randn(batch_size, 10), torch.randint(0, 2, size=(batch_size, 1)).squeeze(1))
        for _ in range(n_batches)
    ]


21
def get_loss_fun():
    """Return the loss function used by these helpers, ``F.nll_loss``.

    ``nll_loss`` expects log-probabilities as its input. If the model emits
    raw scores, apply ``F.log_softmax`` to them before computing the loss.
    """
    return F.nll_loss