import pytest
import torch

from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init():
    """A lazily-allocated parameter holds no payload until it is evaluated."""
    n_in, n_out = 4, 5

    with ColoInitContext(lazy_memory_allocate=True):
        layer = torch.nn.Linear(n_in, n_out, bias=True)

    # lazy_memory_allocate=True, no payload is maintained
    assert layer.weight._torch_tensor.numel() == 0

    # Evaluating the lazy parameter materializes the full weight payload.
    layer.weight.torch_tensor()
    assert layer.weight._torch_tensor.numel() == n_in * n_out


@pytest.mark.skip
def test_device():
    """A lazy parameter lands on the device requested at context-init time."""
    n_in, n_out = 4, 5

    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()):
        layer = torch.nn.Linear(n_in, n_out, bias=True)

    # Evaluate the lazy parameter so its storage is actually allocated.
    layer.weight.torch_tensor()
    assert layer.weight.device == get_current_device()


if __name__ == '__main__':
    # Run the (otherwise skipped) cases directly when invoked as a script.
    for case in (test_lazy_init, test_device):
        case()