mmmt_dummy.py 7.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import argparse
import os
import socket
from functools import partial

import ray
import torch
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
    get_actor_from_args,
    get_critic_from_args,
    get_receivers_per_sender,
    get_reward_model_from_args,
    get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights


def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
25
        s.bind(("", 0))
26
27
28
29
30
        return s.getsockname()[1]


def get_local_ip():
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
31
        s.connect(("8.8.8.8", 80))
32
33
34
35
36
37
38
        return s.getsockname()[0]


def main(args):
    master_addr = str(get_local_ip())
    # trainer_env_info
    trainer_port = str(get_free_port())
39
40
41
42
43
44
45
46
47
48
    env_info_trainers = [
        {
            "local_rank": "0",
            "rank": str(rank),
            "world_size": str(args.num_trainers),
            "master_port": trainer_port,
            "master_addr": master_addr,
        }
        for rank in range(args.num_trainers)
    ]
49
50
51

    # maker_env_info
    maker_port = str(get_free_port())
52
53
54
55
56
57
58
59
60
61
    env_info_makers = [
        {
            "local_rank": "0",
            "rank": str(rank),
            "world_size": str(args.num_makers),
            "master_port": maker_port,
            "master_addr": master_addr,
        }
        for rank in range(args.num_makers)
    ]
62
63
64
65
66
67
68
69
70
71

    # configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
    tokenizer.pad_token = tokenizer.eos_token

    def model_fn():
        actor_cfg = AutoConfig.from_pretrained(args.pretrain)
        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
        actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
72
73
74
75
        reward_model = (
            get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
        )
        if args.initial_model_quant_ckpt is not None and args.model == "llama":
76
77
78
            # quantize initial model
            with low_resource_init(), no_init_weights():
                initial_model = get_actor_from_args(args.model, config=actor_cfg)
79
80
81
82
83
84
85
            initial_model.model = (
                llama_load_quant(
                    initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
                )
                .cuda()
                .requires_grad_(False)
            )
86
87
88
89
90
91
92
93
        else:
            initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
        return actor, critic, reward_model, initial_model

    # configure Experience Maker
    experience_holder_refs = [
        ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
            detached_trainer_name_list=[
94
                f"trainer{x}"
95
96
97
98
99
100
101
                for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
            ],
            strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
            model_fn=model_fn,
            env_info=env_info_maker,
            kl_coef=0.1,
            debug=args.debug,
102
103
            # sync_models_from_trainers=True,
            # generation kwargs:
104
105
106
107
108
109
110
111
112
113
114
115
116
117
            max_length=512,
            do_sample=True,
            temperature=1.0,
            top_k=50,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            eval_performance=True,
            use_cache=True,
        )
        for i, env_info_maker in enumerate(env_info_makers)
    ]

    def trainer_model_fn():
        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
118
119
120
121
122
        critic = (
            get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
            .half()
            .cuda()
        )
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
        return actor, critic

    # configure Trainer
    trainer_refs = [
        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
            experience_maker_holder_name_list=[
                f"maker{x}"
                for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
            ],
            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
            model_fn=trainer_model_fn,
            env_info=env_info_trainer,
            train_batch_size=args.train_batch_size,
            buffer_limit=16,
            eval_performance=True,
            debug=args.debug,
        )
        for i, env_info_trainer in enumerate(env_info_trainers)
    ]

    dataset_size = args.experience_batch_size * 4

    def data_gen_fn():
        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
        attn_mask = torch.ones_like(input_ids)
148
        return {"input_ids": input_ids, "attention_mask": attn_mask}
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164

    def build_dataloader(size):
        dataset = [data_gen_fn() for _ in range(size)]
        dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
        return dataloader

    # uncomment this function if sync_models_from_trainers is True
    # ray.get([
    #     trainer_ref.sync_models_to_remote_makers.remote()
    #     for trainer_ref in trainer_refs
    # ])

    wait_tasks = []

    for experience_holder_ref in experience_holder_refs:
        wait_tasks.append(
165
166
167
168
            experience_holder_ref.workingloop.remote(
                partial(build_dataloader, dataset_size), num_steps=args.experience_steps
            )
        )
169

170
171
172
173
174
175
    total_steps = (
        args.experience_batch_size
        * args.experience_steps
        * args.num_makers
        // (args.num_trainers * args.train_batch_size)
    )
176
177
178
179
180
181
    for trainer_ref in trainer_refs:
        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

    ray.get(wait_tasks)


182
if __name__ == "__main__":
183
    parser = argparse.ArgumentParser()
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    parser.add_argument("--num_makers", type=int, default=1)
    parser.add_argument("--num_trainers", type=int, default=1)
    parser.add_argument(
        "--trainer_strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
        default="ddp",
    )
    parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--critic_pretrain", type=str, default=None)
    parser.add_argument("--experience_steps", type=int, default=4)
    parser.add_argument("--experience_batch_size", type=int, default=8)
    parser.add_argument("--train_epochs", type=int, default=1)
    parser.add_argument("--update_steps", type=int, default=2)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

    parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
    parser.add_argument("--quant_bits", type=int, default=4)
    parser.add_argument("--quant_group_size", type=int, default=128)
    parser.add_argument("--debug", action="store_true")
207
208
209
    args = parser.parse_args()
    ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
    main(args)