# lazy_init_utils.py
1
import random
2
from copy import deepcopy
3
4
5
6
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
7
from packaging import version
8

9
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
10
from colossalai.tensor.d_tensor.layout_converter import to_global
11
12
from tests.kit.model_zoo.registry import ModelAttribute

13
14
# True when the installed torch release is at least 1.12.0.
# NOTE(review): presumably the minimum torch version the lazy-init tests support — confirm against callers.
SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')

15
16
17
18
19
20
21
22
23
24
# (model_fn, data_gen_fn, output_transform_fn, model_attr)
# output_transform_fn is called with the raw model output, so it takes one argument.
TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[Any], dict],
                     Optional[ModelAttribute]]


def set_seed(seed: int) -> None:
    """Seed the global RNGs of ``random``, NumPy and PyTorch with ``seed``."""
    # Seeded in a fixed order: Python stdlib, then NumPy, then torch.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)


25
def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
    """Assert that two modules hold identical state.

    Checks, in order: the two ``state_dict``s have the same number of entries; entries
    paired by iteration order have the same name; paired tensors satisfy
    ``torch.equal``; and parameters paired by iteration order have the same
    ``requires_grad`` flag.

    Args:
        m1: Reference module.
        m2: Module compared against ``m1``.

    Raises:
        AssertionError: On the first mismatch found.
    """
    s1 = m1.state_dict()
    s2 = m2.state_dict()

    assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}'

    for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()):
        assert n1 == n2, f'state_dict key mismatch: {n1} vs {n2}'
        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

    # named_parameters() walks the same parameters as parameters(); the name is only
    # used so that a failure message points at the offending parameter.
    for (n1, p1), (_, p2) in zip(m1.named_parameters(), m2.named_parameters()):
        assert p1.requires_grad == p2.requires_grad, \
            f'{n1} requires_grad {p1.requires_grad} vs {p2.requires_grad}'

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
                         output_transform_fn: Callable[[Any], dict]) -> None:
    """Assert that two modules produce matching outputs for the same generated input.

    Both modules are switched to eval mode and run without gradient tracking on one
    batch from ``data_gen_fn``. Each raw output goes through ``output_transform_fn``;
    the resulting dicts must have the same size and their values must be close
    (``torch.allclose`` with ``atol=1e-5``), looked up by the keys of the first dict.

    Args:
        m1: First module; its class name is used in the failure message.
        m2: Second module.
        data_gen_fn: Returns the keyword arguments passed to both modules.
        output_transform_fn: Maps a raw module output to a dict of tensors.

    Raises:
        AssertionError: If the dict sizes differ or any pair of values is not close.
    """
    inputs = data_gen_fn()

    for module in (m1, m2):
        module.eval()

    # Forward pass only; no autograd graph is needed for the comparison.
    with torch.no_grad():
        raw1, raw2 = m1(**inputs), m2(**inputs)

    transformed_out1, transformed_out2 = map(output_transform_fn, (raw1, raw2))
    assert len(transformed_out1) == len(transformed_out2)

    for key, out1 in transformed_out1.items():
        out2 = transformed_out2[key]
        assert torch.allclose(out1, out2, atol=1e-5), \
            f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'


def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
    """Check that a lazily built model matches a reference build of the same model.

    The model is built three ways from ``entry``'s ``model_fn``: once under
    ``LazyInitContext(tensor_cls=_MyTensor)`` (the reference), once under a default
    ``LazyInitContext`` and then materialized, and once as a ``deepcopy`` of the
    not-yet-materialized model that is materialized separately. The reference is
    compared with the materialized model, and the materialized model with the
    materialized copy.

    Args:
        entry: ``(model_fn, data_gen_fn, output_transform_fn, model_attr)``; the last
            element is not used here.
        seed: Seed passed to ``set_seed`` by the ``_pre_op_fn`` hooks.
        verbose: Forwarded to ``ctx.materialize``; also prints a pass message.
        check_forward: When true, forward outputs are compared as well as state.

    Raises:
        AssertionError: If any of the comparisons fails.
    """
    model_fn, data_gen_fn, output_transform_fn, _ = entry

    # NOTE: these hooks are class attributes and are left installed after this function returns.
    # NOTE(review): presumably both classes call ``_pre_op_fn`` before each op so that random
    # initializers see the same RNG state — confirm in colossalai.lazy.lazy_init.
    _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
    LazyTensor._pre_op_fn = lambda *args: set_seed(seed)

    # Reference model.
    ctx = LazyInitContext(tensor_cls=_MyTensor)
    with ctx:
        model = model_fn()

    # Deferred model, plus a copy taken before materialization.
    ctx = LazyInitContext()
    with ctx:
        deferred_model = model_fn()
        copied_deferred_model = deepcopy(deferred_model)
    deferred_model = ctx.materialize(deferred_model, verbose=verbose)
    copied_deferred_model = ctx.materialize(copied_deferred_model, verbose=verbose)

    assert_model_equal(model, deferred_model)
    assert_model_equal(deferred_model, copied_deferred_model)

    if check_forward:
        assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
        assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn)

    if verbose:
        print(f'{model.__class__.__name__} pass')
82
83


84
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
    """Assert that a distributed model's state matches a reference model's state.

    Entries of the two ``state_dict``s are paired by iteration order. Both tensors of a
    pair are moved to CUDA; if the entry's name is a key of ``layout_dict``, the
    distributed tensor is first passed through ``to_global`` with that layout. The pair
    must then satisfy ``torch.equal``.

    Args:
        model: Reference module.
        distributed_model: Module whose state is compared against ``model``.
        layout_dict: Maps ``state_dict`` names to the layout handed to ``to_global``;
            names that are absent are compared without conversion.

    Raises:
        AssertionError: If the entry counts, an entry name, or a tensor pair differ.
    """
    state = model.state_dict()
    distributed_state = distributed_model.state_dict()

    assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'

    for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
        assert n1 == n2, f'state_dict key mismatch: {n1} vs {n2}'
        t1 = t1.cuda()
        t2 = t2.cuda()
        if n2 in layout_dict:
            t2 = to_global(t2, layout_dict[n2])
        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'