naive.py 4.99 KB
Newer Older
1
2
3
4
import os
import sys
from collections import OrderedDict
from typing import Any, Dict, Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
5
6

import torch
7
import torch.distributed as dist
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
8
9
import torch.nn as nn
import torch.optim as optim
10
from coati.models.base import get_base_model
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
11
from coati.replay_buffer import ReplayBuffer
12
13
14
from coati.models.base import RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
15
16
from torch.optim import Optimizer
from torch.utils.data import DataLoader
17
from transformers.modeling_utils import PreTrainedModel
18
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
19
20
21
22

from .base import Strategy


23
24
25
26
27
28
29
30
31
# TODO: Move this to a util.py (moving it to ray.util introduces a circular import)
def get_grad_required_state_dict(model: nn.Module):
    """Return an ordered mapping of the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are included. Each
    value is the ``detach()``-ed parameter, keyed by the name reported by
    ``model.named_parameters()`` and kept in that iteration order.
    """
    return OrderedDict(
        (param_name, param.detach())
        for param_name, param in model.named_parameters()
        if param.requires_grad
    )


Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
32
33
34
35
36
37
38
39
40
41
42
43
class NaiveStrategy(Strategy):
    """
    Strategy for single GPU. No parallelism is used.

    Most hooks are pass-throughs: the model and optimizer are returned
    unchanged, and saving/loading goes straight through ``torch.save`` /
    ``torch.load``.
    """

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        """Back-propagate ``loss``.

        ``model``, ``optimizer`` and ``kwargs`` are accepted but not used here.
        """
        loss.backward()

    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        """Run a single ``optimizer.step()``; ``kwargs`` are not used here."""
        optimizer.step()

    def setup_distributed(self) -> None:
        """Try to initialize ``torch.distributed`` without forcing it.

        With ``force=False`` any failure inside ``_try_init_dist`` is ignored.
        """
        self._try_init_dist(force=False)

    def setup_model(self, model: nn.Module) -> nn.Module:
        """Return ``model`` unchanged."""
        return model

    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        """Return ``optimizer`` unchanged; ``model`` is not used here."""
        return optimizer

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        """Wrap ``replay_buffer`` in a shuffling ``DataLoader``.

        The batch size and collate function are taken from the buffer
        (``sample_batch_size`` / ``collate_fn``); the last incomplete batch
        is dropped.
        """
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

        ``only_rank0`` is accepted but not used here.
        """
        # NOTE(review): load_model() loads into self.unwrap_model(model), while
        # this saves the state dict of ``model`` as passed in. Confirm the key
        # names match whenever unwrap_model is not the identity.
        state_dict = model.state_dict()
        torch.save(state_dict, path)

    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
        """Load a state dict from ``path`` into the unwrapped ``model``.

        ``map_location`` is forwarded to ``torch.load`` and ``strict`` to
        ``load_state_dict``.
        """
        unwrapped_model = self.unwrap_model(model)
        state_dict = torch.load(path, map_location=map_location)
        unwrapped_model.load_state_dict(state_dict, strict=strict)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        """Write ``optimizer.state_dict()`` to ``path`` with ``torch.save``.

        ``only_rank0`` is accepted but not used here.
        """
        torch.save(optimizer.state_dict(), path)

    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
        """Load an optimizer state dict from ``path`` into ``optimizer``.

        ``map_location`` is forwarded to ``torch.load``.
        """
        state_dict = torch.load(path, map_location=map_location)
        optimizer.load_state_dict(state_dict)

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        """Save the unwrapped model (and optionally a tokenizer) to ``path``.

        The unwrapped model must be a ``PreTrainedModel``; its
        ``save_pretrained`` is called with ``path``. When ``tokenizer`` is
        given, it is saved to the same ``path``. ``only_rank0`` is accepted
        but not used here.
        """
        unwrapped_model = self.unwrap_model(model)
        assert isinstance(unwrapped_model, PreTrainedModel)
        unwrapped_model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """Yield the state dict of the unwrapped ``model``, optionally in shards.

        Recognized ``config`` keys:
            requires_grad_only: when truthy, only parameters with
                ``requires_grad`` set are included (detached).
            shard_size: byte threshold. Entries are accumulated in order and
                a shard is yielded as soon as its accumulated size reaches
                this value. When absent or ``None``, the whole state dict is
                yielded once.

        Yields:
            ``OrderedDict`` shards (or the full state dict when not sharding).
        """
        # TODO: implement sharding on naive strategy
        model = self.unwrap_model(model)
        if config.get('requires_grad_only', False):
            state_dict = get_grad_required_state_dict(model)
        else:
            state_dict = model.state_dict()

        shard_size = config.get('shard_size')
        if shard_size is None:
            yield state_dict
            return

        accumulate_size = 0
        state_dict_shard = OrderedDict()
        for name, param in state_dict.items():
            state_dict_shard[name] = param
            accumulate_size += param.numel() * param.element_size()
            if accumulate_size >= shard_size:
                accumulate_size = 0
                yield state_dict_shard
                state_dict_shard = OrderedDict()
        # Check the shard itself rather than the byte count, so a trailing
        # shard that holds only zero-byte tensors is still emitted.
        if state_dict_shard:
            yield state_dict_shard

    def _try_init_dist(self, force: bool = False) -> None:
        """Initialize an NCCL process group from environment variables.

        Reads ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and
        ``MASTER_PORT`` from ``os.environ``, calls
        ``dist.init_process_group`` and then selects the CUDA device
        ``LOCAL_RANK``.

        Args:
            force: when ``False`` (the default), any failure is ignored;
                when ``True``, failures are raised.

        Raises:
            RuntimeError: if ``force`` is set and a required environment
                variable is missing.
            Exception: if ``force`` is set, any other error is re-raised as is.
        """
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            host = os.environ['MASTER_ADDR']
            port = int(os.environ['MASTER_PORT'])
            dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
            torch.cuda.set_device(local_rank)
        except KeyError as e:
            if force:
                # Chain the original KeyError so the missing variable is traceable.
                raise RuntimeError(
                    f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
                ) from e
        except Exception:
            # Best effort: only surface the failure when the caller insists.
            if force:
                raise