test_checkpoint.py 3.19 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import tempfile
from contextlib import nullcontext
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.models.gpt import GPTActor
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

# Small GPT-2 config (4 layers, 4 heads, 128-dim embeddings); every other field keeps GPT2Config's default.
GPT_CONFIG = GPT2Config(n_layer=4, n_head=4, n_embd=128)


def get_data(batch_size: int, seq_len: int = 10, *, vocab_size: int = 50257, device: str = 'cuda') -> dict:
    """Build a random token batch for one training step.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for the sampled token ids. The default,
            50257, is the value previously hard-coded here. It matches GPT2Config's
            default vocabulary size, which GPT_CONFIG does not override; keep the
            two in sync if that changes.
        device: Device (name or ``torch.device``) the tensors are created on.
            Defaults to CUDA, as before.

    Returns:
        dict with ``input_ids`` (int64 tensor of shape ``(batch_size, seq_len)``,
        values in ``[0, vocab_size)``) and ``attention_mask`` (all ones, same
        shape, dtype and device as ``input_ids``).
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Every position is a real token, so nothing is masked out.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def _build_strategy(strategy):
    """Return a freshly constructed training strategy for the given test name.

    Raises:
        ValueError: if ``strategy`` is not ``'ddp'``, ``'colossalai_gemini'`` or
            ``'colossalai_zero2'``.
    """
    if strategy == 'ddp':
        return DDPStrategy()
    if strategy == 'colossalai_gemini':
        return ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    if strategy == 'colossalai_zero2':
        return ColossalAIStrategy(stage=2, placement_policy='cuda')
    raise ValueError(f'Unsupported strategy "{strategy}"')


def run_test_checkpoint(strategy):
    """Train a step, save and reload a checkpoint, then train another step.

    Executed by every worker process. Rank 0 creates a temporary directory and
    broadcasts its path to all ranks. The model is saved with
    ``only_rank0=True``. The optimizer is saved with ``only_rank0=False`` into a
    per-rank file ``optim-r{rank}.pt``. Every rank then reloads both (model with
    ``strict=False``) before running a second step.

    Args:
        strategy: Name of the strategy to test (see ``_build_strategy``).

    Raises:
        ValueError: for an unsupported strategy name, before any model is built.
    """
    batch_size = 2
    train_strategy = _build_strategy(strategy)

    with train_strategy.model_init_context():
        actor = GPTActor(config=GPT_CONFIG).cuda()
    optimizer = HybridAdam(actor.parameters())
    actor, optimizer = train_strategy.prepare((actor, optimizer))

    def train_one_step():
        # One forward/backward/update on a fresh random batch, using the summed
        # log-probs as a scalar loss. Every token position counts as an action.
        batch = get_data(batch_size)
        num_actions = batch['attention_mask'].size(1)
        log_probs = actor(batch['input_ids'], num_actions, batch['attention_mask'])
        train_strategy.backward(log_probs.sum(), actor, optimizer)
        train_strategy.optimizer_step(optimizer)

    train_one_step()

    rank = dist.get_rank()
    if rank == 0:
        workdir_ctx = tempfile.TemporaryDirectory()
    else:
        workdir_ctx = nullcontext()

    with workdir_ctx as local_dir:
        # Only rank 0 holds a real path here; other ranks enter with None and
        # receive rank 0's path from the broadcast (source rank 0).
        holder = [local_dir]
        dist.broadcast_object_list(holder)
        shared_dir = holder[0]

        model_file = os.path.join(shared_dir, 'model.pt')
        optim_file = os.path.join(shared_dir, f'optim-r{rank}.pt')

        train_strategy.save_model(actor, model_file, only_rank0=True)
        train_strategy.save_optimizer(optimizer, optim_file, only_rank0=False)

        # Wait until every rank has returned from its save calls before any rank loads.
        dist.barrier()

        train_strategy.load_model(actor, model_file, strict=False)
        train_strategy.load_optimizer(optimizer, optim_file)

        # Keep rank 0 from leaving the `with` (which deletes the temporary
        # directory) while other ranks are still inside their load calls.
        dist.barrier()

    train_one_step()


def run_dist(rank, world_size, port, strategy):
    """Worker body for ``mp.spawn``: export rendezvous env vars, then run the test.

    ``rank`` (the spawn process index) is used as both the global and the local
    rank. The rendezvous endpoint is ``localhost:port``.
    """
    os.environ.update({
        'RANK': str(rank),
        'LOCAL_RANK': str(rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(port),
    })
    run_test_checkpoint(strategy)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
    """Spawn ``world_size`` workers that each run the checkpoint round-trip for ``strategy``.

    ``free_port()`` is evaluated inside the body, so every invocation (including
    any retry made by the rerun decorator) asks for a port afresh.
    """
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port, strategy=strategy)
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Direct-run entry point: a single configuration (2 processes, 'colossalai_zero2') without pytest.
    test_checkpoint(world_size=2, strategy='colossalai_zero2')