import argparse
from typing import Callable, List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AlbertForSequenceClassification,
    AutoConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

from data import GLUEDataBuilder

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1  # fraction of total training steps used for LR warmup


def output_transform_fn(x):
    """Identity transform applied to model outputs before loss extraction."""
    return x


def criterion(x):
    """Extract the scalar loss from a HuggingFace-style model output object."""
    return x.loss


def move_to_cuda(batch):
    """Return a copy of *batch* (a dict of tensors) with every value on CUDA."""
    moved = {}
    for key, value in batch.items():
        moved[key] = value.cuda()
    return moved


@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    criterion,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Evaluate ``model`` on the GLUE task ``task_name``.

    Supports both the plain (data-parallel) path and pipeline parallelism
    (``HybridParallelPlugin`` with ``pp_size > 1``).  With pipelining only the
    last pipeline stage produces logits/loss, so it broadcasts them to the
    other ranks in its pipeline group so that every rank feeds the same batch
    into the (distributed) metric.

    Args:
        model: the boosted model to evaluate.
        criterion: callable ``(outputs, inputs) -> loss`` used by the
            pipeline schedule.
        test_dataloader: a single dataloader, or one per entry of
            ``eval_splits``.
        num_labels: number of classification labels; 1 means regression-style
            outputs that are squeezed instead of argmax'd.
        task_name: GLUE task name passed to ``evaluate.load``.
        eval_splits: split names used to suffix metric keys when several
            dataloaders are given.
        booster: Booster wrapping the active plugin.
        coordinator: distributed coordinator (rank/world-size/master info).

    Returns:
        The metric dict from ``evaluate`` (with ``"loss"`` added on the
        master rank); for multiple splits, keys are suffixed ``_<split>``.
    """
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def evaluate_subset(dataloader: DataLoader):
        # Pipelining changes both how the forward pass is driven and which
        # rank owns the outputs.
        use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)

        accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
        for batch in dataloader:
            batch = move_to_cuda(batch)
            labels = batch["labels"]
            if use_pipeline:
                pg_mesh = booster.plugin.pg_mesh
                pp_group = booster.plugin.pp_group
                current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                current_rank = dist.get_rank()
                # execute_pipeline consumes an iterator of micro-batches.
                batch = iter([batch])

                outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

                if is_pp_last_device:
                    logits = outputs["outputs"]["logits"]
                    val_loss = outputs["loss"]
                    accum_loss.add_(val_loss)

                    if num_labels > 1:
                        preds = torch.argmax(logits, axis=1)
                    elif num_labels == 1:
                        preds = logits.squeeze()

                    # Share predictions/loss with the rest of the pipeline
                    # group; all ranks in the group must enter the collective.
                    dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(predictions=preds, references=labels)
                elif current_rank in current_pp_group_ranks:
                    # Non-last stages receive the broadcast so every rank
                    # contributes the same batch to the distributed metric.
                    object_list = [None, None]
                    dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(
                        predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
                    )
                    accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))

            else:
                # Non-pipeline path.  NOTE: the batch was already moved to the
                # device at the top of the loop; the original code redundantly
                # called move_to_cuda a second time here.
                outputs = model(**batch)
                val_loss, logits = outputs[:2]
                accum_loss.add_(val_loss)

                if num_labels > 1:
                    preds = torch.argmax(logits, axis=1)
                elif num_labels == 1:
                    preds = logits.squeeze()

                metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        # Average the loss over batches locally, then sum across ranks; the
        # division by world_size below turns the sum into a mean.
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master() and results is not None:
            results["loss"] = accum_loss.item() / coordinator.world_size
        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)
    else:
        assert len(test_dataloader) == len(eval_splits)
        final_results = {}
        for split, sub_loader in zip(eval_splits, test_dataloader):
            results = evaluate_subset(sub_loader)
            # Tag each metric with its split name, e.g. "accuracy_validation".
            final_results.update({f"{k}_{split}": v for k, v in results.items()})
        return final_results


def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Train ``model`` for one epoch over ``train_dataloader``.

    Handles both the plain (data-parallel) path and the pipeline-parallel
    path of ``HybridParallelPlugin``: with pipelining, the booster drives the
    forward/backward micro-batch schedule itself and only the last pipeline
    stage sees the loss.

    Args:
        epoch: zero-based epoch index (used only in the progress-bar label).
        model: the boosted model; set to train mode and updated in place.
        optimizer: boosted optimizer; stepped once per batch.
        _criterion: callable ``(outputs, inputs) -> loss``.
        lr_scheduler: stepped once per batch, after the optimizer step.
        train_dataloader: training batches for this epoch.
        booster: Booster wrapping the active plugin.
        coordinator: used to pick which rank renders the progress bar.
    """
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True)
    # Only one rank draws the progress bar: the master rank normally, or the
    # last pipeline stage (the only rank that holds the loss) when pipelining.
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar:
        # Forward pass
        for _ in pbar:
            if use_pipeline:
                # The booster consumes micro-batches from the iterator and
                # runs the whole schedule, including the backward pass.
                outputs = booster.execute_pipeline(
                    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                )
                # Backward and optimize
                if is_pp_last_device:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                data = next(train_dataloader_iter)
                data = move_to_cuda(data)
                outputs = model(**data)
                loss = _criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            # Optimizer step runs on every rank in both paths; the LR
            # scheduler is advanced per batch (matches total_steps below).
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


def main():
    """Entry point: parse CLI args, launch the distributed environment,
    build the requested booster plugin, fine-tune a BERT/ALBERT model on a
    GLUE task, and print (and optionally assert on) evaluation metrics."""
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
        help="plugin to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    # BUG FIX: the original used type=bool, but bool("False") is True — any
    # non-empty string parsed as True.  Parse the string value explicitly.
    parser.add_argument(
        "--use_lazy_init",
        type=lambda s: s.lower() in ("true", "1", "yes"),
        default=False,
        help="for initiating lazy init context",
    )
    args = parser.parse_args()

    if args.model_type == "bert":
        model_name = "bert-base-uncased"
    elif args.model_type == "albert":
        model_name = "albert-xxlarge-v2"
    else:
        raise RuntimeError(f"unknown model type {args.model_type!r} (expected 'bert' or 'albert')")

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # Scale the base learning rate linearly with the number of workers.
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly for finetuning test cases
        plugin = HybridParallelPlugin(
            tp_size=1,
            pp_size=2,
            num_microbatches=None,
            pp_style="interleaved",
            num_model_chunks=2,
            microbatch_size=16,
            enable_all_optimization=True,
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(
        model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # bert pretrained model
    cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

    if model_name == "bert-base-uncased":
        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
    elif model_name == "albert-xxlarge-v2":
        model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
    else:
        raise RuntimeError(f"unsupported model {model_name!r}")

    # optimizer: exclude bias and LayerNorm weights from weight decay, as is
    # standard for BERT-style fine-tuning.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler: linear warmup, then linear decay across all training steps
    # (the scheduler is stepped once per batch in train_epoch).
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    def _criterion(outputs, inputs):
        # Pipeline schedules call the criterion as fn(outputs, inputs); the
        # loss itself is taken from the HF model output via `criterion`.
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _criterion, _, lr_scheduler = booster.boost(
        model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
    )

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate_model(
        model,
        _criterion,
        test_dataloader,
        data_builder.num_labels,
        args.task,
        data_builder.eval_splits,
        booster,
        coordinator,
    )

    if coordinator.is_master():
        print(results)
        # Optional regression gate for CI-style runs.
        if args.target_f1 is not None and "f1" in results:
            assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'


# Run only when executed as a script (e.g. via torchrun), not on import.
if __name__ == "__main__":
    main()