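"""Finetune Mixtral-8x7B with ColossalAI's MoE hybrid parallel plugin.

Meant to be launched via torchrun. A minimal sketch of an invocation (flag
values are illustrative, not prescriptive):

    torchrun --nproc_per_node 4 train.py --pp_size 2 --ep_size 2
"""
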
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@torch.no_grad()
def get_global_loss(loss, booster):
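    # All-reduce the detached loss across the data parallel group and average,
    # so the logged value is the dp-wide mean rather than one rank's loss.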
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
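        # Synthetic data: random token ids stand in for a real corpus. The
        # tokenizer argument is accepted for interface parity but unused here;
        # labels mirror input_ids, which yields the standard causal LM loss.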
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help=" The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communicate overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
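        # tp_size is fixed at 1 here: Mixtral layers are split across pipeline
        # stages (pp_size) and experts across expert-parallel ranks (ep_size),
        # while the remaining ranks form the ZeRO data parallel group.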
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            custom_policy=MixtralForCausalLMPolicy(),
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master(f"Finish init model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
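    # prepare_dataloader shards sampling across the data parallel group, so
    # each dp rank draws a distinct slice of the dataset every epoch.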
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
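    # HybridAdam is ColossalAI's fused Adam; adamw_mode=True uses decoupled
    # weight decay (AdamW-style).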
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
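    # Cosine decay from args.lr down to 10% of it; warmup defaults to 2.5% of
    # total steps when --warmup_steps is not given.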
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
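    # With pipeline parallelism, only the last stage computes the loss, so
    # progress reporting is driven from that stage rather than the master rank.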
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master(f"Finish init booster")

    # Load ckpt
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master(f"Finish load optimizer")

    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
    for epoch in range(args.num_epochs):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(dataloader)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epochs}]",
            disable=not is_pp_last_stage if use_pipeline else not coordinator.is_master(),
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward and backward pass (scheduled over microbatches)
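                    # The criterion lambda receives (outputs, inputs) and
                    # returns the causal LM loss from the model outputs.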
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                    )
                    # Loss is only available on the last pipeline stage
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)
                # Save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # Save the model at the end of each epoch
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)

    # Finish training
    coordinator.print_on_master(f"Finish training")


if __name__ == "__main__":
    main()