_util.py 757 Bytes
Newer Older
Ziyue Jiang's avatar
Ziyue Jiang committed
1
2
3
import torch
import torch.distributed as dist

ver217's avatar
ver217 committed
4

Ziyue Jiang's avatar
Ziyue Jiang committed
5
6
7
def check_equal(A, B):
    """Assert that tensors ``A`` and ``B`` are element-wise close.

    Uses ``torch.allclose`` with ``rtol=1e-3`` and ``atol=1e-1``, the same
    tolerances as ``tensor_equal``.

    Raises:
        AssertionError: if the tensors are not close. The message reports the
            largest absolute element-wise difference to ease debugging.
    """
    # allclose already returns a Python bool; comparing it to True is redundant.
    # The message expression is only evaluated when the assertion fails.
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1), (
        f"tensors not close (rtol=1e-3, atol=1e-1); "
        f"max abs diff = {(A - B).abs().max().item()}"
    )

ver217's avatar
ver217 committed
8

Ziyue Jiang's avatar
Ziyue Jiang committed
9
10
11
12
13
14
15
16
17
18
def replace_parameter_add_grad(layer, weight=None, bias=None):
    """Swap ``layer.weight`` / ``layer.bias`` for the given tensors and enable grads.

    Each replacement that is not ``None`` is handled in order (weight first,
    then bias). The existing attribute is deleted, the new object is assigned
    under the same name, and ``requires_grad`` is set to ``True`` on it.
    Arguments left as ``None`` leave the corresponding attribute untouched.
    """
    for attr_name, replacement in (('weight', weight), ('bias', bias)):
        if replacement is None:
            continue
        # Delete first so the new object is not constrained by how the old
        # attribute was registered on the layer.
        delattr(layer, attr_name)
        setattr(layer, attr_name, replacement)
        getattr(layer, attr_name).requires_grad = True

ver217's avatar
ver217 committed
19

Ziyue Jiang's avatar
Ziyue Jiang committed
20
21
22
def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0, src=0):
    """Broadcast ``tensor`` from rank ``src`` and return one chunk of it.

    Requires an initialized default ``torch.distributed`` process group.

    Args:
        tensor: Tensor to broadcast. On ranks other than ``src`` it is
            overwritten in place by ``dist.broadcast``.
        chunk_size: Number of chunks to split the last dimension into. It is
            passed as the ``chunks`` argument of ``torch.chunk``, so despite
            its name it is a count, not a per-chunk size.
        local_rank: Index of the chunk to return.
        src: Rank to broadcast from. Defaults to 0, the previously hard-coded
            source.

    Returns:
        A clone of the selected chunk, so the result does not alias ``tensor``.
    """
    dist.broadcast(tensor, src=src)
    tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
    # torch.chunk returns a view; clone so callers get independent storage.
    return tensor_chunk.clone()


def tensor_equal(A, B):
    """Return whether ``A`` and ``B`` are element-wise close.

    Tolerances: relative 1e-3, absolute 1e-1 (via ``torch.allclose``).
    """
    is_close = torch.allclose(A, B, atol=1e-1, rtol=1e-3)
    return is_close