pipeline_strategy.py 3.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# WIP


import os
import random
from functools import partial
from typing import Optional

import numpy as np
import torch
from torch._C._distributed_rpc import _is_current_rpc_agent_set

import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine

from coati.models.base import Actor, RewardModel, Critic
from coati.trainer.strategies import Strategy
from coati.trainer.strategies import NaiveStrategy

# Public alias for torch's private RPC helper. Judging by its name, it presumably
# reports whether an RPC agent is set for the current process (i.e. whether RPC
# has been initialized). It is not used in the code visible here.
# NOTE(review): presumably kept for importers elsewhere — confirm before removing.
rpc_is_initialized = _is_current_rpc_agent_set

def _pipeline_partition(model: torch.nn.Module,
                        data_kwargs: dict,
                        pp_rank: int,
                        chunk: int,
                        stage_num: int) -> torch.nn.Module:
    """Trace ``model``, split it into ``stage_num`` balanced stages and return stage ``pp_rank``.

    Defined at module level rather than nested inside ``PipelineModel.__init__``:
    pickle cannot serialize functions defined inside another function. Being
    picklable is needed if the pipeline engine sends ``partition_fn`` (a
    ``functools.partial`` of this function) to RPC workers.

    Args:
        model: module to trace and split. It is switched to eval mode in place.
        data_kwargs: example inputs keyed by forward-argument name. Each value is
            moved to the ``meta`` device and passed to the tracer as a meta argument.
        pp_rank: index of the pipeline stage whose submodule is returned.
        chunk: not used here. NOTE(review): presumably supplied by the pipeline
            engine when it calls ``partition_fn`` — confirm.
        stage_num: number of stages the traced graph is split into.

    Returns:
        The submodule for stage ``pp_rank``.
    """
    model.eval()
    meta_args = {name: value.to('meta') for name, value in data_kwargs.items()}
    graph = ColoTracer().trace(root=model, meta_args=meta_args)
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
    annotated_model = balanced_split_pass(gm, stage_num)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
    # Attach the topology computed from the top module to every split GraphModule.
    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            submodule._topo = topo
    # NOTE(review): the +1 presumably skips the root/top module at index 0 — confirm.
    return split_submodules[pp_rank + 1]


class PipelineModel(torch.nn.Module):
    '''
    Wrap a module in a one-forward-one-backward pipeline engine used for inference.

    Actor has 2 kinds of jobs: forward and generate, so it is better to
    pipelinize only the inner model rather than the Actor itself.

    Args:
        model: module to be traced and split into pipeline stages.
        stage_num: number of pipeline stages.
        num_microbatches: number of microbatches passed to the pipeline engine.
        data_kwargs: example inputs (keyed by forward-argument name) used to trace
            ``model``. ``None`` is treated as an empty mapping.
    '''

    def __init__(self,
                 model: torch.nn.Module,
                 stage_num: int,
                 num_microbatches: int,
                 data_kwargs: Optional[dict] = None,
                 ):
        super().__init__()
        # The partition function iterates over data_kwargs. Normalize None here so
        # the failure is not an AttributeError raised inside the partition call.
        if data_kwargs is None:
            data_kwargs = {}
        self.inference_engine = OneFOneBPipelineEngine(
            partition_fn=partial(_pipeline_partition, model, data_kwargs),
            stage_num=stage_num,
            num_microbatches=num_microbatches,
            device='cuda',
        )

    def forward(self,
                **model_inputs):
        """Run a forward-only pass of the pipeline engine on ``model_inputs``.

        NOTE(review): the keyword names must match what
        ``OneFOneBPipelineEngine.forward_backward`` accepts — confirm.
        """
        return self.inference_engine.forward_backward(**model_inputs, forward_only=True)



class PPStrategy(NaiveStrategy):
    """
    Strategy for pipeline-parallel inference (inference only!).

    Master node only.

    Args:
        seed: seed passed to ``colossalai.launch_from_torch`` in ``setup_distributed``.
        stage_num: number of pipeline stages each wrapped model is split into.
            Must be set before ``setup_model`` wraps an Actor/RewardModel/Critic.
        num_microbatches: number of microbatches used by the pipeline engine.
        data_kwargs: example inputs (keyed by forward-argument name) used to trace
            the wrapped models.
    """

    def __init__(
        self,
        seed: int = 42,
        stage_num: Optional[int] = None,
        num_microbatches: int = 1,
        data_kwargs: Optional[dict] = None,
    ):
        # Attributes are assigned before super().__init__(), as the original did
        # for `seed`. NOTE(review): presumably the base constructor may call
        # setup_distributed(), which reads self.seed — confirm.
        self.seed = seed
        self.stage_num = stage_num
        self.num_microbatches = num_microbatches
        self.data_kwargs = data_kwargs
        super().__init__()

    def setup_distributed(self) -> None:
        """Launch ColossalAI from torch env vars and register pipeline process-group info.

        Reads ``RANK`` and ``WORLD_SIZE`` from the environment. Uses no data or
        tensor parallelism (degree 1), 128 RPC worker threads and CUDA devices.
        """
        colossalai.launch_from_torch({}, seed=self.seed)
        ppg.set_global_info(rank=int(os.environ['RANK']),
                            world_size=int(os.environ['WORLD_SIZE']),
                            dp_degree=1,
                            tp_degree=1,
                            num_worker_threads=128,
                            device="cuda")

    def model_init_context(self):
        """Return the parent strategy's model-initialization context unchanged."""
        return super().model_init_context()

    def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
        """Pipelinize the inner ``.model`` of coati Actor/RewardModel/Critic instances.

        Other modules are returned untouched.

        Returns:
            ``model`` itself. For the coati wrappers, ``model.model`` is replaced
            in place by a ``PipelineModel``.

        Raises:
            ValueError: if a model must be wrapped but ``stage_num`` was not given.
        """
        if isinstance(model, (Actor, RewardModel, Critic)):
            if self.stage_num is None:
                raise ValueError("PPStrategy needs `stage_num` to pipelinize a model; "
                                 "pass it to the PPStrategy constructor.")
            model.model = PipelineModel(model.model,
                                        stage_num=self.stage_num,
                                        num_microbatches=self.num_microbatches,
                                        data_kwargs=self.data_kwargs)
        return model

    def set_seed(self, seed: int) -> None:
        """Seed Python's ``random``, NumPy and torch with ``seed``."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)