utils.py 2.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
import torch
import os

def is_rank_0() -> bool:
    """Tell whether this process should behave as the primary (rank 0) process.

    Returns:
        ``True`` when ``torch.distributed`` has no initialized process group
        (e.g. a plain single-process run). Otherwise, ``True`` only if
        ``dist.get_rank()`` reports rank 0.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    # No process group yet: the lone process is treated as rank 0.
    return True


def get_cuda_actor_critic_from_args(model: str, pretrained: Optional[str] = None, lora_rank: int = 0):
    """Build an actor/critic model pair for ``model`` on the current CUDA device.

    Args:
        model: Model family name; one of ``'gpt2'``, ``'bloom'`` or ``'opt'``.
        pretrained: Passed unchanged as ``pretrained=`` to both the actor and
            critic constructors. Defaults to ``None``.
        lora_rank: Passed unchanged as ``lora_rank=`` to both constructors.

    Returns:
        An ``(actor, critic)`` tuple. Each model has been moved with
        ``.to(torch.cuda.current_device())``.

    Raises:
        ValueError: If ``model`` is not a supported family name.
    """
    # Resolve the classes first so an unsupported name fails fast, before
    # any CUDA query or (potentially expensive) model construction.
    if model == 'gpt2':
        actor_cls, critic_cls = GPTActor, GPTCritic
    elif model == 'bloom':
        actor_cls, critic_cls = BLOOMActor, BLOOMCritic
    elif model == 'opt':
        actor_cls, critic_cls = OPTActor, OPTCritic
    else:
        raise ValueError(f'Unsupported model "{model}"')

    device = torch.cuda.current_device()
    actor = actor_cls(pretrained=pretrained, lora_rank=lora_rank).to(device)
    critic = critic_cls(pretrained=pretrained, lora_rank=lora_rank).to(device)
    return actor, critic


def get_strategy_from_args(strategy: str):
    """Instantiate the training strategy named by ``strategy``.

    Recognised names and what they construct:
        ``'naive'``             -> ``NaiveStrategy()``
        ``'ddp'``               -> ``DDPStrategy()``
        ``'colossalai_gemini'`` -> ``ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)``
        ``'colossalai_zero2'``  -> ``ColossalAIStrategy(stage=2, placement_policy='cuda')``

    Raises:
        ValueError: If ``strategy`` matches none of the names above.
    """
    # Zero-argument factories keep construction lazy: only the selected
    # strategy is ever instantiated.
    choices = (
        ('naive', lambda: NaiveStrategy()),
        ('ddp', lambda: DDPStrategy()),
        ('colossalai_gemini', lambda: ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)),
        ('colossalai_zero2', lambda: ColossalAIStrategy(stage=2, placement_policy='cuda')),
    )
    for name, build in choices:
        if strategy == name:
            return build()
    raise ValueError(f'Unsupported strategy "{strategy}"')


def set_dist_env(env_info: Dict[str, str]):
    """Copy distributed-launch settings from ``env_info`` into ``os.environ``.

    Keys are mapped to environment variables as follows:
        ``'rank'``        -> ``RANK``
        ``'local_rank'``  -> ``LOCAL_RANK``
        ``'world_size'``  -> ``WORLD_SIZE``
        ``'master_port'`` -> ``MASTER_PORT``
        ``'master_addr'`` -> ``MASTER_ADDR``

    Variables are written in the order above. A missing key raises
    ``KeyError`` after the earlier variables have already been set. Values
    are assigned verbatim, so they must be ``str`` (``os.environ`` rejects
    other types).
    """
    for env_var, info_key in (
        ("RANK", 'rank'),
        ("LOCAL_RANK", 'local_rank'),
        ("WORLD_SIZE", 'world_size'),
        ('MASTER_PORT', 'master_port'),
        ('MASTER_ADDR', 'master_addr'),
    ):
        os.environ[env_var] = env_info[info_key]