# finetune.py - fine-tune BERT for sequence classification on a GLUE task with the ColossalAI Booster API.
import argparse
from typing import List, Union

import datasets
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
# ==============================
# Number of full passes over the training dataloader.
NUM_EPOCHS = 1
# Batch size handed to the data builder for both training and evaluation.
BATCH_SIZE = 32
# Base learning rate; main() multiplies it by the world size.
LEARNING_RATE = 2.4e-5
# Weight decay for parameters that are not biases or LayerNorm weights.
WEIGHT_DECAY = 0.01
# Fraction of the total training steps used for linear LR warmup.
WARMUP_FRACTION = 0.1


def move_to_cuda(batch):
    """Return a new dict mapping each key of ``batch`` to ``value.cuda()``.

    The input mapping itself is left untouched; every value must expose a
    ``.cuda()`` method (presumably torch tensors -- confirm against the data builder).
    """
    on_device = {}
    for name, value in batch.items():
        on_device[name] = value.cuda()
    return on_device


@torch.no_grad()
def evaluate(
    model: nn.Module,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    coordinator: DistCoordinator,
):
    """Evaluate ``model`` with the GLUE metric on one or several dataloaders.

    Args:
        model: Model called as ``model(**batch)``; the first two outputs are
            taken as the loss and the logits.
        test_dataloader: A single dataloader, or a list with one dataloader per
            entry of ``eval_splits``.
        num_labels: ``> 1`` selects argmax predictions (classification),
            ``1`` uses the raw logits as predictions (regression).
        task_name: GLUE configuration name passed to ``datasets.load_metric``.
        eval_splits: Split names used to suffix result keys when several
            dataloaders are given.
        coordinator: Supplies rank, world size and the master check.

    Returns:
        For a single dataloader: whatever ``metric.compute()`` returns, with an
        extra ``"loss"`` entry on the master process. ``metric.compute()`` is
        documented to return ``None`` on non-main processes.
        For a list of dataloaders: a dict whose keys are ``"{metric}_{split}"``;
        processes that received ``None`` contribute nothing, so the dict is
        empty there.

    Raises:
        ValueError: If ``num_labels`` is smaller than 1 (raised when the first
            batch is processed).
    """
    metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def evaluate_subset(dataloader: DataLoader):
        accum_loss = torch.zeros(1, device=get_current_device())
        for batch in dataloader:
            batch = move_to_cuda(batch)
            outputs = model(**batch)
            val_loss, logits = outputs[:2]
            accum_loss.add_(val_loss)

            if num_labels > 1:
                preds = torch.argmax(logits, dim=1)
            elif num_labels == 1:
                # Only drop the trailing label dimension: a bare squeeze() would
                # also remove the batch dimension when the batch holds one sample.
                preds = logits.squeeze(-1)
            else:
                raise ValueError(f"num_labels must be >= 1, got {num_labels}")

            labels = batch["labels"]

            metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        # Sum the per-process mean losses; the master divides by world size below.
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master():
            results["loss"] = accum_loss.item() / coordinator.world_size
        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)
    else:
        assert len(test_dataloader) == len(eval_splits)
        final_results = {}
        for split, sub_loader in zip(eval_splits, test_dataloader):
            results = evaluate_subset(sub_loader)
            # Non-main processes get None from metric.compute(); skip them
            # instead of crashing on `None.items()`.
            if results is None:
                continue
            final_results.update({f"{k}_{split}": v for k, v in results.items()})
        return final_results


81
82
83
84
85
86
87
88
89
def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    lr_scheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Train ``model`` for one pass over ``train_dataloader``.

    For every batch: forward pass, backward through ``booster.backward``,
    optimizer step, gradient reset, then one LR-scheduler step. A tqdm progress
    bar showing the current loss is displayed on the master process only.

    Args:
        epoch: Zero-based epoch index (shown one-based in the progress bar).
        model: Model called as ``model(**batch)``; its first output is the loss.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        train_dataloader: Source of training batches.
        booster: Booster whose ``backward`` performs the backward pass.
        coordinator: Used to decide whether this process shows the progress bar.
    """
    model.train()

    on_master = coordinator.is_master()
    progress = tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not on_master)
    with progress as bar:
        for raw_batch in bar:
            # Forward pass
            inputs = move_to_cuda(raw_batch)
            loss = model(**inputs)[0]

            # Backward and optimize
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

            # Print log info
            bar.set_postfix({"loss": loss.item()})
106
107
108
109
110
111
112


def main():
    """Fine-tune ``bert-base-uncased`` on a GLUE task and report evaluation results.

    Steps: parse the command line (``--task``, ``--plugin``, ``--target_f1``),
    launch the distributed environment, build the selected booster plugin, the
    dataloaders, the model, the optimizer and the LR schedule, train for
    ``NUM_EPOCHS`` epochs, then evaluate once. The master process prints the
    results.

    Raises:
        AssertionError: On the master process, when ``--target_f1`` is given,
            the results contain an ``"f1"`` entry, and that score is not at
            least the target.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
        help="plugin to use",
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    args = parser.parse_args()

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # BATCH_SIZE is passed unchanged to the data builder below (it is not
    # divided by the world size); the learning rate is instead scaled linearly
    # with the number of processes.
    lr = LEARNING_RATE * coordinator.world_size
    model_name = "bert-base-uncased"

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    # Both "torch_ddp" and "torch_ddp_fp16" use the DDP plugin; the fp16 variant
    # differs only in the booster's mixed-precision setting above.
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(placement_policy="static", strict_ddp_mode=True, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(
        model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # bert pretrained model
    config = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
    model = BertForSequenceClassification.from_pretrained(model_name, config=config)

    # optimizer: parameters whose name contains "bias" or "LayerNorm.weight"
    # are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler: linear warmup over WARMUP_FRACTION of all steps, then linear decay
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate(
        model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator
    )

    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and "f1" in results:
            # Raise explicitly: a plain `assert` is stripped under `python -O`,
            # which would silently disable this check. `not (a >= b)` rather than
            # `a < b` keeps the original outcome for a NaN score.
            if not results["f1"] >= args.target_f1:
                raise AssertionError(f'f1 score {results["f1"]} is lower than target {args.target_f1}')
209
210


211
# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()