import argparse
from typing import Callable, List, Union
3

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AlbertForSequenceClassification,
    AutoConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1

# The model output is passed through unchanged; the criterion reads its .loss field.
output_transform_fn = lambda x: x
criterion = lambda x: x.loss


def move_to_cuda(batch):
    return {k: v.cuda() for k, v in batch.items()}


@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    criterion,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    booster: Booster,
    coordinator: DistCoordinator,
):
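    """Evaluate the model on one or more GLUE evaluation splits, syncing metrics across ranks."""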
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def evaluate_subset(dataloader: DataLoader):
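        # With pipeline parallelism, only the last stage produces logits and loss;
        # the other ranks in the pipeline group receive them via broadcast.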
        use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
        is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

        accum_loss = torch.zeros(1, device=get_current_device())
        for batch in dataloader:
            batch = move_to_cuda(batch)
            labels = batch["labels"]
            if use_pipeline:
                pg_mesh = booster.plugin.pg_mesh
                pp_group = booster.plugin.pp_group
                current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                current_rank = dist.get_rank()
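                # execute_pipeline consumes an iterator of micro-batches, so wrap the single batch.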
                batch = iter([batch])
                outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

                if is_pp_last_stage:
                    logits = outputs["outputs"]["logits"]
                    val_loss = outputs["loss"]
                    accum_loss.add_(val_loss)

                    if num_labels > 1:
                        preds = torch.argmax(logits, dim=1)
                    elif num_labels == 1:
                        preds = logits.squeeze()

                    # The last stage broadcasts predictions and loss to the other ranks in its pipeline group.
                    dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(predictions=preds, references=labels)
                elif current_rank in current_pp_group_ranks:
                    # Non-last stages receive [preds, val_loss] broadcast from the last stage.
                    object_list = [None, None]
                    dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
                    accum_loss.add_(object_list[1].to(get_current_device()))

            else:
                outputs = model(**batch)
                val_loss, logits = outputs[:2]
                accum_loss.add_(val_loss)

                if num_labels > 1:
                    preds = torch.argmax(logits, dim=1)
                elif num_labels == 1:
                    preds = logits.squeeze()

                metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        # accum_loss holds this rank's mean loss; all_reduce sums it across ranks,
        # and the master divides by the world size below to get the global mean.
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master() and results is not None:
            results["loss"] = accum_loss.item() / coordinator.world_size

        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)
    else:
        assert len(test_dataloader) == len(eval_splits)
        final_results = {}
        for split, sub_loader in zip(eval_splits, test_dataloader):
            results = evaluate_subset(sub_loader)
            final_results.update({f"{k}_{split}": v for k, v in results.items()})
        return final_results


def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
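    """Train the model for one epoch, dispatching to pipeline execution when the plugin enables it."""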
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    # Only the rank that actually sees the loss drives the progress bar: the master
    # process without pipelining, or the last pipeline stage with it.
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar:
        # One optimizer step per iteration: forward, backward, then parameter update.
        for _ in pbar:
            if use_pipeline:
                # execute_pipeline runs both the forward and backward passes across stages;
                # optimizer.step() is still invoked explicitly below.
                outputs = booster.execute_pipeline(
                    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                )
                # Backward already ran inside execute_pipeline; only the last stage holds the loss.
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                data = next(train_dataloader_iter)
                data = move_to_cuda(data)
                outputs = model(**data)
                loss = _criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
        help="plugin to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
    args = parser.parse_args()

    if args.model_type == "bert":
        model_name = "bert-base-uncased"
    elif args.model_type == "albert":
        model_name = "albert-xxlarge-v2"
    else:
        raise RuntimeError(f"Unsupported model type: {args.model_type}")

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # Scale the learning rate linearly with the number of processes (linear scaling rule).
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
        # Adjust the parallelism parameters below to suit the fine-tuning workload.
        plugin = HybridParallelPlugin(
            tp_size=1,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_all_optimization=True,
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(
        model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # pretrained sequence classification model (BERT or ALBERT)
    cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

    if model_name == "bert-base-uncased":
        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
    elif model_name == "albert-xxlarge-v2":
        model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
    else:
        raise RuntimeError(f"Unsupported model: {model_name}")

    # optimizer: biases and LayerNorm weights are excluded from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # ==============================
    # Boost with ColossalAI
    # ==============================
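    # boost() wraps the model, optimizer, criterion, and lr_scheduler so they work
    # with the selected plugin (DDP, Gemini, low-level ZeRO, or hybrid parallelism).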
    model, optimizer, _criterion, _, lr_scheduler = booster.boost(
        model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
    )

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate_model(
        model,
        _criterion,
        test_dataloader,
        data_builder.num_labels,
        args.task,
        data_builder.eval_splits,
        booster,
        coordinator,
    )

    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and "f1" in results:
            assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'


if __name__ == "__main__":
    main()