test_experience.py 4.67 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
import os
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
7
from coati.experience_buffer import NaiveExperienceBuffer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
8
9
10
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
11
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
12
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13
14
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

15
from colossalai.testing import rerun_if_address_is_in_use, spawn
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)


def get_data(batch_size: int, seq_len: int = 10, device: str = 'cuda', vocab_size: int = 50257) -> dict:
    """Create a random prompt batch for the experience maker.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of tokens per sequence.
        device: device on which the tensors are allocated. Defaults to ``'cuda'``,
            which is what the distributed test uses; pass ``'cpu'`` to build the
            batch on a host without a GPU.
        vocab_size: exclusive upper bound for the sampled token ids. The default
            50257 is presumably the GPT-2 vocabulary size -- confirm against
            ``GPT_CONFIG`` if the model config changes.

    Returns:
        A dict with two ``(batch_size, seq_len)`` integer tensors:
        ``input_ids`` (uniformly random ids in ``[0, vocab_size)``) and
        ``attention_mask`` (all ones, same dtype and device as ``input_ids``).
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # No padding is generated, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def gather_and_equal(tensor: torch.Tensor) -> bool:
    """Tell whether ``tensor`` is identical on every rank.

    The tensor is all-gathered from each rank of the default process group and
    every gathered copy is compared against the copy coming from rank 0.

    Args:
        tensor: the local tensor to compare across ranks.

    Returns:
        ``True`` if all gathered copies are equal (per ``torch.equal``),
        ``False`` as soon as one of them differs.
    """
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    # The collective is given a contiguous view of the local tensor.
    dist.all_gather(gathered, tensor.contiguous())
    return all(torch.equal(gathered[0], other) for other in gathered[1:])


36
def make_and_consume_experience(strategy):
    """Generate experiences, buffer them, and check how they are spread over ranks.

    Three properties are asserted, in this order:

    1. the prompt batch and every tensor of the generated experience are
       identical on all ranks;
    2. the experience buffer has the same length and the same items on all ranks;
    3. the dataloader returned by ``strategy.setup_dataloader`` has the same
       length on all ranks but yields different batches on each of them.

    Args:
        strategy: name of the strategy to instantiate; one of ``'ddp'``,
            ``'colossalai-zero2'`` or ``'colossalai-gemini'``.

    Raises:
        ValueError: if ``strategy`` is not one of the supported names.
        AssertionError: if any of the checks above fails.
    """
    experience_batch_size = 4
    sample_batch_size = 2

    # Experience tensors compared across ranks.
    all_fields = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'action_mask',
                  'attention_mask')
    # Tensors expected to differ between ranks once the buffer is sharded by the dataloader.
    sharded_fields = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages')

    # Factories are called lazily so that only the requested strategy is constructed.
    strategy_factories = {
        'ddp': DDPStrategy,
        'colossalai-zero2': LowLevelZeroStrategy,
        'colossalai-gemini': lambda: GeminiStrategy(placement_policy='cuda'),
    }
    if strategy not in strategy_factories:
        raise ValueError(f'Unsupported strategy "{strategy}"')
    dist_strategy = strategy_factories[strategy]()

    actor = GPTActor(config=GPT_CONFIG).cuda()
    critic = GPTCritic(config=GPT_CONFIG).cuda()

    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    data_buffer = NaiveExperienceBuffer(sample_batch_size, cpu_offload=False)

    generate_kwargs = dict(do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256)

    # experience of all ranks should be the same
    for _ in range(2):
        data = get_data(experience_batch_size)
        for key in ('input_ids', 'attention_mask'):
            assert gather_and_equal(data[key])
        experience = experience_maker.make_experience(**data, **generate_kwargs)
        for field in all_fields:
            assert gather_and_equal(getattr(experience, field))
        data_buffer.append(experience)

    # data buffer's data should be the same
    buffer_size = torch.tensor([len(data_buffer)], device='cuda')
    assert gather_and_equal(buffer_size)
    for item in data_buffer.items:
        for field in all_fields:
            assert gather_and_equal(getattr(item, field))

    # dataloader of each rank should have the same size and different batch
    dataloader = dist_strategy.setup_dataloader(data_buffer)
    dataloader_size = torch.tensor([len(dataloader)], device='cuda')
    assert gather_and_equal(dataloader_size)
    for experience in dataloader:
        # action mask and attention mask may be same, so they are not checked here
        for field in sharded_fields:
            assert not gather_and_equal(getattr(experience, field))


def run_dist(rank, world_size, port, strategy):
    """Per-process entry point: export the rendezvous settings and run the check.

    The rank, world size and master address/port are written to the process
    environment (as strings) before ``make_and_consume_experience`` is called;
    presumably the strategy reads them when it sets up the process group --
    confirm against the strategy implementation.

    Args:
        rank: rank of this process; also used as ``LOCAL_RANK``.
        world_size: total number of processes.
        port: port of the master process on ``localhost``.
        strategy: strategy name forwarded to ``make_and_consume_experience``.
    """
    rendezvous = {
        'RANK': rank,
        'LOCAL_RANK': rank,
        'WORLD_SIZE': world_size,
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': port,
    }
    # Environment values must be strings.
    os.environ.update({name: str(value) for name, value in rendezvous.items()})
    make_and_consume_experience(strategy)
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
109
110
111
112


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
    """Run ``run_dist`` in ``world_size`` spawned processes for the given strategy.

    Args:
        world_size: number of processes handed to ``spawn``.
        strategy: strategy name forwarded to every worker as a keyword argument.
    """
    worker_kwargs = dict(strategy=strategy)
    spawn(run_dist, world_size, **worker_kwargs)
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
117
118
119


if __name__ == '__main__':
    # The strategy must be one of the names accepted by
    # make_and_consume_experience ('ddp', 'colossalai-zero2',
    # 'colossalai-gemini'); the former value 'colossalai' is not among them
    # and made every worker raise ValueError.
    test_experience(2, 'colossalai-zero2')