test_experience.py 4.48 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
5
6
import os
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
7
from coati.experience_buffer import NaiveExperienceBuffer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
8
9
10
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
11
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
12
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
13
14
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

15
from colossalai.testing import rerun_if_address_is_in_use, spawn
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
16
17
18
19
20

# Small GPT-2 configuration (128-dim embeddings, 4 layers, 4 attention heads)
# shared by the actor and critic built in this test.
GPT_CONFIG = GPT2Config(n_head=4, n_layer=4, n_embd=128)


def get_data(batch_size: int, seq_len: int = 10, *, vocab_size: int = 50257, device: str = "cuda") -> dict:
    """Build a random batch of prompt token ids with an all-ones attention mask.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Length of every sequence.
        vocab_size: Exclusive upper bound for the sampled token ids. The default,
            50257, is the default ``GPT2Config.vocab_size``.
        device: Device on which both tensors are allocated.

    Returns:
        A dict with ``input_ids`` (random integers in ``[0, vocab_size)``) and
        ``attention_mask`` (all ones, same shape/dtype/device as ``input_ids``),
        both of shape ``(batch_size, seq_len)``.
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # No padding is generated, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def gather_and_equal(tensor: torch.Tensor) -> bool:
    """Return True when ``tensor`` has identical contents on every rank.

    Each rank's copy is collected with ``dist.all_gather`` and compared with
    the copy gathered from rank 0. This is a collective operation: every rank
    of the default process group has to call it.
    """
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor.contiguous())
    reference, *others = gathered
    return all(torch.equal(reference, other) for other in others)


def make_and_consume_experience(strategy):
    """Check that experience making and buffering stay consistent across ranks.

    The check builds small GPT actor/critic models and makes experience twice
    from freshly sampled prompts. It asserts that the prompts, every generated
    experience and every buffered item are identical on all ranks. Finally it
    asserts that the strategy's dataloader has the same length on every rank
    while yielding different batches per rank.

    Args:
        strategy: Strategy name; one of ``"ddp"``, ``"colossalai-zero2"`` or
            ``"colossalai-gemini"``.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    EXPERIENCE_BATCH_SIZE = 4
    SAMPLE_BATCH_SIZE = 2
    # Experience fields the test expects to be identical on every rank.
    SYNCED_FIELDS = (
        "sequences",
        "action_log_probs",
        "values",
        "reward",
        "advantages",
        "action_mask",
        "attention_mask",
    )
    # Fields expected to DIFFER across ranks once the dataloader splits the
    # buffer; action/attention masks are excluded because they may coincide.
    SPLIT_FIELDS = ("sequences", "action_log_probs", "values", "reward", "advantages")

    if strategy == "ddp":
        train_strategy = DDPStrategy()
    elif strategy == "colossalai-zero2":
        train_strategy = LowLevelZeroStrategy()
    elif strategy == "colossalai-gemini":
        train_strategy = GeminiStrategy(placement_policy="cuda")
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')

    actor = GPTActor(config=GPT_CONFIG).cuda()
    critic = GPTCritic(config=GPT_CONFIG).cuda()

    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)

    # gather_and_equal is a collective, so every rank must run these checks in
    # the same order (the field tuples fix that order). Every rank evaluates the
    # same gathered data, so a failing check fails on all ranks together instead
    # of leaving some ranks blocked in the next collective.
    for _ in range(2):
        data = get_data(EXPERIENCE_BATCH_SIZE)
        assert gather_and_equal(data["input_ids"]), "input_ids differ across ranks"
        assert gather_and_equal(data["attention_mask"]), "attention_mask differs across ranks"
        # 50256 is the default GPT2Config.eos_token_id; it is reused as padding.
        experience = experience_maker.make_experience(
            **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
        )
        for name in SYNCED_FIELDS:
            assert gather_and_equal(getattr(experience, name)), f"experience.{name} differs across ranks"
        data_buffer.append(experience)

    # The buffer's contents should also be the same on every rank.
    buffer_size = torch.tensor([len(data_buffer)], device="cuda")
    assert gather_and_equal(buffer_size), "buffer size differs across ranks"
    for item in data_buffer.items:
        for name in SYNCED_FIELDS:
            assert gather_and_equal(getattr(item, name)), f"buffer item {name} differs across ranks"

    # Each rank's dataloader should have the same length but yield different batches.
    dataloader = train_strategy.setup_dataloader(data_buffer)
    dataloader_size = torch.tensor([len(dataloader)], device="cuda")
    assert gather_and_equal(dataloader_size), "dataloader length differs across ranks"
    for batch in dataloader:
        for name in SPLIT_FIELDS:
            assert not gather_and_equal(getattr(batch, name)), f"batch {name} is identical on every rank"


def run_dist(rank, world_size, port, strategy):
    """Per-process worker launched by ``spawn``.

    The worker exports the standard ``env://`` rendezvous variables for a
    single-node job: the local rank equals the global rank and the master is
    ``localhost``. It then runs the experience checks for ``strategy``.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes.
        port: Port on ``localhost`` used as ``MASTER_PORT``.
        strategy: Strategy name forwarded to ``make_and_consume_experience``.
    """
    os.environ.update(
        RANK=str(rank),
        LOCAL_RANK=str(rank),
        WORLD_SIZE=str(world_size),
        MASTER_ADDR="localhost",
        MASTER_PORT=str(port),
    )
    make_and_consume_experience(strategy)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
    """Spawn ``world_size`` workers that each run the experience checks for ``strategy``."""
    spawn(run_dist, world_size, strategy=strategy)


if __name__ == "__main__":
    # Run every strategy name accepted by make_and_consume_experience; a bare
    # "colossalai" is not one of them and would raise ValueError.
    for strategy_name in ("ddp", "colossalai-zero2", "colossalai-gemini"):
        test_experience(2, strategy_name)