# finetune.py (11.1 KB)
1
import argparse
2
from typing import Callable, List, Union
3
4
5
6
7
8

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
9
from torch.optim import Optimizer
10
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
11
12
13
14
15
16
17
18
19
20
21
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AlbertForSequenceClassification,
    AutoConfig,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)

import colossalai
from colossalai.booster import Booster
22
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
23
24
25
26
27
28
29
30
31
32
33
34
35
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
# Number of passes over the training dataloader; also shown in the progress-bar label.
NUM_EPOCHS = 3
# Passed to GLUEDataBuilder as both the train and the eval batch size.
BATCH_SIZE = 32
# Base learning rate; main() multiplies it by the distributed world size.
LEARNING_RATE = 2.4e-5
# Weight decay for every parameter except biases and LayerNorm weights.
WEIGHT_DECAY = 0.01
# Fraction of the total optimizer steps used for linear learning-rate warmup.
WARMUP_FRACTION = 0.1

36
37
38
def output_transform_fn(x):
    """Return the model outputs *x* unchanged (identity hook applied before the loss is read)."""
    return x


def criterion(x):
    """Return the ``loss`` attribute carried by the model outputs *x*."""
    return x.loss

39
40
41
42
43
44

def move_to_cuda(batch):
    """Return a new dict with the same keys as *batch* and each value replaced by ``value.cuda()``."""
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.cuda()
    return on_device


@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    criterion,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Evaluate *model* on one or several GLUE evaluation dataloaders.

    Args:
        model: The (boosted) model to evaluate; it is switched to eval mode.
        criterion: Loss callable handed to ``booster.execute_pipeline`` when
            pipeline parallelism is active.
        test_dataloader: A single dataloader, or a list with one dataloader
            per entry of *eval_splits*.
        num_labels: ``> 1`` selects argmax over the logits, ``== 1`` uses the
            squeezed logits as predictions.
        task_name: GLUE task name forwarded to ``evaluate.load``.
        eval_splits: Split names used to suffix the metric keys when several
            dataloaders are given.
        booster: Booster whose plugin decides whether the pipeline path is used.
        coordinator: Provides rank, world size and the master check.

    Returns:
        For a single dataloader: whatever ``metric.compute()`` returns (with a
        ``"loss"`` entry added on the master rank when the result is not None).
        For a list of dataloaders: a dict whose keys are ``"<metric>_<split>"``;
        splits for which this rank got no result are skipped.

    Raises:
        ValueError: If predictions are requested while ``num_labels < 1``.
    """
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def logits_to_preds(logits):
        # Shared by the pipeline and non-pipeline paths.
        if num_labels > 1:
            return torch.argmax(logits, axis=1)
        if num_labels == 1:
            return logits.squeeze()
        raise ValueError(f"num_labels must be >= 1, got {num_labels}")

    def evaluate_subset(dataloader: DataLoader):
        use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
        # NOTE(review): -1 is presumably the last model chunk under the interleaved schedule — confirm.
        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
            None if not booster.plugin.stage_manager.is_interleave else -1
        )

        accum_loss = torch.zeros(1, device=get_current_device())
        for batch in dataloader:
            batch = move_to_cuda(batch)
            labels = batch["labels"]
            if use_pipeline:
                pg_mesh = booster.plugin.pg_mesh
                pp_group = booster.plugin.pp_group
                current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                current_rank = dist.get_rank()
                # execute_pipeline is given an iterator, so wrap the single batch.
                batch = iter([batch])

                outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

                if is_pp_last_device:
                    logits = outputs["outputs"]["logits"]
                    val_loss = outputs["loss"]
                    accum_loss.add_(val_loss)

                    preds = logits_to_preds(logits)

                    # Share predictions and loss with the other ranks of this pipeline group.
                    dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(predictions=preds, references=labels)
                elif current_rank in current_pp_group_ranks:
                    # Receive what the last rank of the pipeline group broadcast above.
                    object_list = [None, None]
                    dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

                    metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
                    accum_loss.add_(object_list[1].to(get_current_device()))

            else:
                # The batch was already moved to the GPU at the top of the loop.
                outputs = model(**batch)
                val_loss, logits = outputs[:2]
                accum_loss.add_(val_loss)

                preds = logits_to_preds(logits)

                metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        # Per-rank mean loss, combined across ranks and then divided by the world size below.
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master() and results is not None:
            results["loss"] = accum_loss.item() / coordinator.world_size

        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)

    assert len(test_dataloader) == len(eval_splits)
    final_results = {}
    for split, sub_loader in zip(eval_splits, test_dataloader):
        results = evaluate_subset(sub_loader)
        # evaluate_subset may return None (see its own `results is not None`
        # check); skip such splits instead of failing on `.items()`.
        if results is not None:
            final_results.update({f"{k}_{split}": v for k, v in results.items()})
    return final_results


128
129
130
131
132
133
134
135
136
137
def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Run one training epoch over *train_dataloader*.

    Args:
        epoch: Zero-based epoch index; only used for the progress-bar label.
        model: The (boosted) model; it is switched to train mode.
        optimizer: Optimizer that is stepped and zeroed once per iteration.
        _criterion: Loss callable invoked as ``_criterion(outputs, None)`` on
            the non-pipeline path, or handed to ``booster.execute_pipeline``.
        lr_scheduler: Scheduler stepped once per iteration.
        train_dataloader: Source of training batches.
        booster: Booster that runs the backward pass or the pipeline schedule.
        coordinator: Used to decide which rank shows the progress bar.
    """
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    # NOTE(review): -1 is presumably the last model chunk under the interleaved schedule — confirm.
    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
        None if not booster.plugin.stage_manager.is_interleave else -1
    )
    # The bar is shown by the last pipeline stage when pipelining, otherwise by the master rank.
    print_flag = is_pp_last_device if use_pipeline else coordinator.is_master()
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar:
        for _ in pbar:
            if use_pipeline:
                # The optimizer is handed to execute_pipeline; no separate
                # booster.backward call is made on this path.
                outputs = booster.execute_pipeline(
                    train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
                )
                # The loss is only read on the last pipeline stage.
                if is_pp_last_device:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                # Forward pass
                data = next(train_dataloader_iter)
                data = move_to_cuda(data)
                outputs = model(**data)
                loss = _criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            # Optimize
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` treats every non-empty string (including ``"False"``) as
    True; this converter interprets the text instead. The empty string maps to
    False, matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Fine-tune BERT or ALBERT on a GLUE task with a ColossalAI booster plugin.

    Parses the command line, launches the distributed environment, builds the
    plugin, data, model, optimizer and scheduler, trains for ``NUM_EPOCHS``
    epochs, evaluates, and on the master rank prints the results and enforces
    ``--target_f1`` when it is given.

    Raises:
        RuntimeError: If ``--model_type`` is neither ``bert`` nor ``albert``.
        AssertionError: If ``--target_f1`` is set, an ``"f1"`` result exists
            and it is below the target.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
        help="plugin to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="bert",
        help="bert or albert",
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    # NOTE: this flag is parsed but not read anywhere else in this function.
    parser.add_argument("--use_lazy_init", type=str2bool, default=False, help="for initiating lazy init context")
    args = parser.parse_args()

    if args.model_type == "bert":
        model_name = "bert-base-uncased"
    elif args.model_type == "albert":
        model_name = "albert-xxlarge-v2"
    else:
        raise RuntimeError(f"unsupported model_type {args.model_type!r}; expected 'bert' or 'albert'")

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # Scale the base learning rate by the number of processes.
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly for finetuning test cases
        plugin = HybridParallelPlugin(
            tp_size=1,
            pp_size=2,
            num_microbatches=None,
            pp_style="interleaved",
            num_model_chunks=2,
            microbatch_size=16,
            enable_all_optimization=True,
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(
        model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # bert pretrained model

    cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

    if model_name == "bert-base-uncased":
        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
    elif model_name == "albert-xxlarge-v2":
        # NOTE(review): unlike the BERT branch, no .cuda() call here — confirm this is intentional.
        model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
    else:
        raise RuntimeError(f"no model class configured for {model_name!r}")

    # optimizer
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    def _criterion(outputs, inputs):
        # `inputs` is accepted for the (outputs, inputs) call signature but not used.
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _criterion, _, lr_scheduler = booster.boost(
        model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
    )

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate_model(
        model,
        _criterion,
        test_dataloader,
        data_builder.num_labels,
        args.task,
        data_builder.eval_splits,
        booster,
        coordinator,
    )

    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and "f1" in results:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # `not (a >= b)` keeps the original outcome for NaN scores as well.
            if not results["f1"] >= args.target_f1:
                raise AssertionError(f'f1 score {results["f1"]} is lower than target {args.target_f1}')
321
322


323
# Run the fine-tuning entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()