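"""Trainer for the nanotron Mamba example.

`MambaTrainer` specializes `DistributedTrainer` with Mamba-specific parameter tying and
with checkpoint loading / random initialization for models described by `MambaConfig`.
"""
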
from typing import Optional, Type, Union

from torch.nn.parallel import DistributedDataParallel

from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import ParallelismArgs
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
from nanotron.parallel.tensor_parallel.nn import (
    TensorParallelLinearMode,
    TensorParallelRowLinear,
)
from nanotron.parallel.tied_parameters import (
    create_pg_for_tied_weights,
    get_tied_id_to_param,
    tie_parameters,
)
from nanotron.serialize import load_weights, parse_ckpt_path
from nanotron.trainer import DistributedTrainer

from config import ExistingCheckpointInit, MambaConfig, MambaInit

logger = logging.get_logger(__name__)


class MambaTrainer(DistributedTrainer):
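    """`DistributedTrainer` specialized for Mamba models configured through `MambaConfig`."""
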
    def __init__(
        self,
        config_or_config_file: Union[MambaConfig, str],
        config_class: Type[MambaConfig] = MambaConfig,
        model_config_class: Optional[Type] = None,
        model_class: Optional[Type[NanotronModel]] = None,
    ):
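        """Build the trainer from a `MambaConfig` instance or a path to a config file."""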
        assert config_class == MambaConfig, f"MambaTrainer expects config_class=MambaConfig, got {config_class}"
        super().__init__(config_or_config_file, config_class, model_config_class, model_class)

    def _mark_tied_parameters(
        self,
        model: NanotronModel,
        parallel_context: ParallelContext,
        parallel_config: Optional[ParallelismArgs] = None,
    ):
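        """Mark which parameters are tied so nanotron can keep them in sync.

        Ties the embedding/lm-head weights across pipeline ranks, applies the model's
        custom ties, marks every non-sharded, non-tied parameter as tied across the
        tensor-parallel group, and finally creates the process groups for tied weights.
        """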
        # Tie embeddings
        embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names()
        if len(embeddings_lm_head_tied_names) > 0:
            shared_embeddings = [
                (
                    target,
                    (
                        parallel_context.world_rank_matrix[
                            dist.get_rank(parallel_context.expert_pg),
                            get_pp_rank_of(target, module=model),
                            dist.get_rank(parallel_context.dp_pg),
                            dist.get_rank(parallel_context.tp_pg),
                        ],
                    ),
                )
                for target in embeddings_lm_head_tied_names
            ]
            tie_parameters(
                root_module=model,
                ties=shared_embeddings,
                parallel_context=parallel_context,
                reduce_op=dist.ReduceOp.SUM,
            )

        # Tie custom params
        model.tie_custom_params()

        # Sync all parameters that have the same name and that are not sharded
        assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point"
        for module_name, module in model.named_modules():
            for param_name, param in module.named_parameters(recurse=False):
                name = f"{module_name}.{param_name}"

                if isinstance(param, NanotronParameter) and (param.is_sharded or param.is_tied):
                    continue

                if isinstance(module, TensorParallelRowLinear) and "bias" == param_name:
                    # the bias of TensorParallelRowLinear only exists on TP rank 0, so there is nothing to tie
                    continue

                shared_weights = [
                    (
                        name,
                        # sync across TP group
                        tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
                    )
                ]

                if (
                    parallel_config is None
                    or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE
                    or hasattr(model.config.model.model_config, "is_mamba_config")
                ):
                    # We pass `reduce_op=None` to signal that the weights are synced by design and need no reduce:
                    # e.g. with TP=2 the LayerNorm is duplicated across TP ranks, so it is tied by construction.
                    reduce_op = None
                else:
                    reduce_op = dist.ReduceOp.SUM

                tie_parameters(
                    root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
                )

        create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context)

    def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
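        """Load model weights from a checkpoint, or initialize them if none is given.

        Resumes from the parsed checkpoint path when one exists; otherwise either loads
        the checkpoint named in the config (`ExistingCheckpointInit`) or initializes the
        model randomly (`MambaInit`) and synchronizes parameters across the data-parallel
        and tied-weight groups.
        """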
        unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model

        # Load or initialize model weights
        self.init_checkpoint_path = parse_ckpt_path(config=self.config)
        reloaded_from_checkpoint = False
        if self.init_checkpoint_path is not None:
            # Reload from a training checkpoint
            log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
            self.param_shard_metadata = load_weights(
                model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
            )
            reloaded_from_checkpoint = True
        if not reloaded_from_checkpoint:
            log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO)
            if isinstance(self.config.model.init_method, ExistingCheckpointInit):
                # Initialize the model from a pretrained checkpoint
                self.param_shard_metadata = load_weights(
                    model=unwrapped_model,
                    parallel_context=self.parallel_context,
                    root_folder=self.config.model.init_method.path,
                )
            elif isinstance(self.config.model.init_method, MambaInit):
                unwrapped_model.init_model_randomly(config=self.config)
                # Synchronize parameters so that the model is consistent
                # sync all params across dp
                for name, param in sorted(model.named_parameters(), key=lambda x: x[0]):
                    dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)

                # sync tied params across tied groups
                for (_, group_ranks), param in sorted(
                    get_tied_id_to_param(
                        parameters=model.parameters(),
                        root_module=unwrapped_model,
                    ).items(),
                    key=lambda x: x[0],
                ):
                    group = self.parallel_context.world_ranks_to_pg[group_ranks]
                    dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
            else:
                raise ValueError(f"Unsupported init_method: {self.config.model.init_method}")

        return model
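

# Minimal usage sketch (assumes a run script and config file like the ones shipped with the
# nanotron Mamba example; the names below are illustrative, not defined in this module):
#
#   trainer = MambaTrainer(config_or_config_file="config_mamba.yaml")
#   trainer.train(dataloader)  # dataloader is built elsewhere, e.g. in the example's run_train script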