import os
from collections import OrderedDict
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer


def is_rank_0() -> bool:
    """Return True when running non-distributed, or on the global rank-0 process."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def get_rank() -> int:
    """Return the global rank of this process; 0 when torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """Return the distributed world size; 1 when torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
    """Instantiate an actor model by name.

    Args:
        model: one of "gpt2", "bloom", "opt", "llama".
        pretrained: optional checkpoint path/name forwarded to the actor.
        config: optional model config forwarded to the actor.
        lora_rank: LoRA rank forwarded to the actor (0 presumably disables LoRA — confirm in coati).

    Raises:
        ValueError: if *model* is not a supported actor name.
    """
    kwargs = dict(pretrained=pretrained, config=config, lora_rank=lora_rank)
    if model == "gpt2":
        return GPTActor(**kwargs)
    if model == "bloom":
        return BLOOMActor(**kwargs)
    if model == "opt":
        return OPTActor(**kwargs)
    if model == "llama":
        return LlamaActor(**kwargs)
    raise ValueError(f'Unsupported actor model "{model}"')


def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
    """Instantiate a critic model by name.

    Args:
        model: one of "gpt2", "bloom", "opt", "llama".
        pretrained: optional checkpoint path/name forwarded to the critic.
        config: optional model config forwarded to the critic.
        lora_rank: LoRA rank forwarded to the critic (0 presumably disables LoRA — confirm in coati).

    Raises:
        ValueError: if *model* is not a supported critic name.
    """
    if model == "gpt2":
        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "bloom":
        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "opt":
        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "llama":
        critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    else:
        # Fixed copy-paste from get_reward_model_from_args: this builds a critic,
        # so the error message must say "critic model", not "reward model".
        raise ValueError(f'Unsupported critic model "{model}"')
    return critic


def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
    """Instantiate a reward model by name.

    Args:
        model: one of "gpt2", "bloom", "opt", "llama".
        pretrained: optional checkpoint path/name forwarded to the reward model.
        config: optional model config forwarded to the reward model.

    Raises:
        ValueError: if *model* is not a supported reward-model name.
    """
    if model == "gpt2":
        return GPTRM(pretrained=pretrained, config=config)
    if model == "bloom":
        return BLOOMRM(pretrained=pretrained, config=config)
    if model == "opt":
        return OPTRM(pretrained=pretrained, config=config)
    if model == "llama":
        return LlamaRM(pretrained=pretrained, config=config)
    raise ValueError(f'Unsupported reward model "{model}"')


def get_strategy_from_args(strategy: str):
    """Build a training strategy by name.

    Args:
        strategy: one of "ddp", "colossalai_gemini", "colossalai_zero2",
            "colossalai_gemini_cpu", "colossalai_zero2_cpu".

    Raises:
        ValueError: if *strategy* is not a supported name.
    """
    # Lazy factories: the chosen strategy is constructed only after lookup succeeds.
    factories = {
        "ddp": lambda: DDPStrategy(),
        "colossalai_gemini": lambda: GeminiStrategy(placement_policy="static", initial_scale=2**5),
        "colossalai_zero2": lambda: LowLevelZeroStrategy(stage=2, placement_policy="cuda"),
        "colossalai_gemini_cpu": lambda: GeminiStrategy(
            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
        ),
        "colossalai_zero2_cpu": lambda: LowLevelZeroStrategy(stage=2, placement_policy="cpu"),
    }
    if strategy not in factories:
        raise ValueError(f'Unsupported strategy "{strategy}"')
    return factories[strategy]()


def get_tokenizer_from_args(model: str, **kwargs):
    """Load the HuggingFace tokenizer matching *model*.

    Args:
        model: one of "gpt2", "bloom", "opt", "llama".
        **kwargs: for "llama", must contain "pretrain" — the checkpoint
            path/name whose tokenizer is loaded.

    Raises:
        ValueError: if *model* is not a supported name.
    """
    if model == "llama":
        # llama has no fixed hub checkpoint here; the caller supplies its path.
        tokenizer = AutoTokenizer.from_pretrained(kwargs["pretrain"])
    elif model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    elif model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    elif model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{model}"')

    # Reuse EOS as the pad token so batched generation works out of the box.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def set_dist_env(env_info: Dict[str, str]):
    """Export the torch.distributed rendezvous settings from *env_info* into os.environ.

    Expects the string keys "rank", "local_rank", "world_size",
    "master_port" and "master_addr".
    """
    for env_var, info_key in (
        ("RANK", "rank"),
        ("LOCAL_RANK", "local_rank"),
        ("WORLD_SIZE", "world_size"),
        ("MASTER_PORT", "master_port"),
        ("MASTER_ADDR", "master_addr"),
    ):
        os.environ[env_var] = env_info[info_key]


def get_model_numel(model: nn.Module) -> int:
    """Return the total number of parameter elements in *model* (trainable and frozen)."""
    return sum(param.numel() for param in model.parameters())


def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
    """Return the receiver indices that sender *sender_idx* should send data to.

    Args:
        sender_idx: index of this sender in [0, num_senders).
        num_senders: total number of senders.
        num_receivers: total number of receivers.
        allow_idle_sender: when senders outnumber receivers, True keeps the
            round-robin fan-out (some senders may get no receivers), False
            instead assigns every sender exactly one receiver.
    """
    if num_senders > num_receivers and not allow_idle_sender:
        # Each sender targets exactly one receiver; a receiver may
        # therefore be fed by several senders.
        return [sender_idx % num_receivers]
    # Fan out: receivers are partitioned round-robin across senders,
    # so each receiver has exactly one sender.
    return [recv for recv in range(num_receivers) if recv % num_senders == sender_idx]


133
134
135
136
137
138
def state_dict_to(
    state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
    """Return a copy of *state_dict* with every tensor cast to *dtype* on *device*.

    The input state_dict is left untouched; Tensor.to allocates new tensors
    when the dtype or device actually changes.
    """
    return OrderedDict(
        (name, tensor.to(dtype=dtype, device=device)) for name, tensor in state_dict.items()
    )