# benchmark_fsdp.py - benchmarks OpenMoE training under PyTorch FSDP using randomly generated token data.
import argparse
import functools
import os

import torch
import torch.distributed as dist
import tqdm
from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

from colossalai.moe.manager import MOE_MANAGER


class RandomDataset(Dataset):
    """Synthetic dataset of uniformly random token ids.

    All samples are generated once at construction time. Each item is a dict
    with ``input_ids``, an all-ones ``attention_mask`` and ``labels`` (the
    same token ids as ``input_ids``).

    Args:
        num_samples: Number of sequences to generate.
        max_length: Length of every sequence.
        vocab_size: Exclusive upper bound for the sampled token ids.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        # One (num_samples, max_length) draw covers the whole dataset.
        shape = (num_samples, max_length)
        self.input_ids = torch.randint(low=0, high=vocab_size, size=shape)
        # Nothing is padded, so every position is marked as attended.
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        """Return the number of generated sequences."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        sample = dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_mask[idx],
        )
        # Labels are the input tokens themselves.
        sample["labels"] = self.input_ids[idx]
        return sample


def fsdp_main(rank, world_size, args):
    """Benchmark OpenMoE training under PyTorch FSDP on random token data.

    Builds a ``RandomDataset``, wraps an ``OpenMoeForCausalLM`` in FSDP with
    bfloat16 mixed precision, runs one forward/backward/optimizer step per
    batch and reports the measurements through ``PerformanceEvaluator``.
    Rank 0 additionally prints the peak CUDA memory allocated.

    Args:
        rank: Index of the CUDA device this process binds to (the script
            entry point passes ``LOCAL_RANK``).
        world_size: Number of launched processes. Kept for interface
            compatibility; the data-parallel size is read from the
            initialised process group instead.
        args: Parsed command-line namespace providing ``model_name``,
            ``batch_size``, ``seq_length``, ``warmup`` and ``active``.
    """
    # initialize the process group
    dist.init_process_group("nccl")

    MOE_MANAGER.setup(parallel=None)

    dp_size = dist.get_world_size()
    # Size the dataset so each rank sees exactly (warmup + active) batches.
    dataset = RandomDataset(
        max_length=args.seq_length,
        num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
    )
    # Shard by the *global* rank of the process group: `rank` is the local
    # device index, which repeats across nodes and would hand the same shard
    # to several processes in a multi-node run.
    sampler = DistributedSampler(dataset, rank=dist.get_rank(), num_replicas=dp_size, shuffle=False)
    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
    torch.cuda.set_device(rank)

    config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model_name}")
    set_openmoe_args(
        config,
        num_experts=config.num_experts,
        moe_layer_interval=config.moe_layer_interval,
        enable_load_balance=False,
        enable_kernel=False,
        enable_comm_overlap=False,
    )
    # Build the model with float16 as the default dtype, and always restore
    # the float32 default afterwards, even if construction fails.
    torch.set_default_dtype(torch.float16)
    try:
        model = OpenMoeForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float32)
    # Wrap every decoder layer in its own FSDP unit.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            OpenMoeDecoderLayer,
        },
    )
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
    model.train()

    model_numel = get_model_numel(model)
    # NOTE(review): enable_grad_checkpoint=True is passed to the evaluator, but
    # no gradient checkpointing is enabled in this function - confirm intent.
    performance_evaluator = PerformanceEvaluator(
        model_numel,
        enable_grad_checkpoint=True,
        ignore_steps=args.warmup,
        dp_world_size=dp_size,
    )

    for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
        performance_evaluator.on_step_start(step)
        input_ids, attention_mask, labels = (
            data["input_ids"].cuda(),
            data["attention_mask"].cuda(),
            data["labels"].cuda(),
        )

        optimizer.zero_grad()
        output = model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            chunk_head=False,
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        performance_evaluator.on_step_end(input_ids)

    performance_evaluator.on_fit_end()
    if dist.get_rank() == 0:
        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    # Command-line options: model variant plus the integer benchmark knobs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="base", choices=["base", "8b"], help="base or 8b")
    for flag, default in (
        ("--batch_size", 1),
        ("--seq_length", 2048),
        ("--warmup", 20),
        ("--active", 20),
    ):
        parser.add_argument(flag, type=int, default=default)
    args = parser.parse_args()

    # Fixed seed for torch's random number generator.
    torch.manual_seed(42)

    # WORLD_SIZE and LOCAL_RANK must be present in the environment.
    env = os.environ
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env["LOCAL_RANK"])
    fsdp_main(local_rank, world_size, args)