# test_accelerator.py
from functools import partial

import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.booster.accelerator import Accelerator
from colossalai.testing import parameterize, rerun_if_address_is_in_use


@parameterize('device', ['cpu', 'cuda'])
def run_accelerator(device):
    """Check that Accelerator places a model's parameters on the requested device.

    Args:
        device (str): target device type, either ``'cpu'`` or ``'cuda'``.
    """
    accelerator = Accelerator(device)
    model = nn.Linear(8, 8)
    model = accelerator.configure_model(model)
    # configure_model must have moved every parameter onto `device`
    assert next(model.parameters()).device.type == device

    # Drop references so GPU memory is released between parameterized runs.
    del model, accelerator


def run_dist(rank):
    """Per-process entry point for ``mp.spawn``; runs the accelerator checks.

    Args:
        rank: process rank supplied by ``mp.spawn`` (unused here).
    """
    run_accelerator()


@rerun_if_address_is_in_use()
def test_accelerator():
    """Spawn a single worker process and run the accelerator checks inside it."""
    world_size = 1
    # The original wrapped run_dist in `partial(run_dist)` without binding any
    # arguments, which is a no-op — pass the function to mp.spawn directly.
    mp.spawn(run_dist, nprocs=world_size)