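"""Inference for Mixtral-8x7B with ColossalAI expert parallelism (infer.py).

Builds the model skeleton without initializing weights, swaps in expert-parallel
MoE layers, loads the pretrained checkpoint through the Booster, and runs a
small batched generation example.
"""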
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="ep",
        choices=["ep"],
        help="Parallel method.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="Precision used for inference.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optimizations. Requires flash-attn and triton for all optimizations; they are skipped if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use the fused layernorm kernel. Requires apex; raises an error if not installed.",
    )
    args = parser.parse_args()
    return args
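
# Example launch command (assumed setup; adjust the GPU count and model path
# to your environment):
#   torchrun --standalone --nproc_per_node 8 infer.py --plugin ep --precision bf16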


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    booster_kwargs = {}
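    # Plugin configuration: no tensor parallelism (tp_size=1), ZeRO stage 1,
    # and the Mixtral-specific shard policy / checkpoint IO from colossal_moe.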
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": MixtralForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "checkpoint_io": MixtralMoECheckpointIO,
        "zero_stage": 1,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
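        # Expert parallelism (EP): distribute the MoE experts across the ranks;
        # max_ep_size equals the world size so experts can be spread over every rank.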
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build the Mixtral model skeleton (weights are loaded from the checkpoint later)
    config = MixtralConfig.from_pretrained(args.model_name)
    config.num_local_experts = 1  # don't change this; replace_moe_layer rebuilds the experts below, so it does not affect the model
    with skip_init():
        model = MixtralForCausalLM(config)
    model.num_experts = 8  # Mixtral-8x7B uses 8 experts per MoE layer
    # Cast to the requested precision; fp32 keeps the default float32 weights
    if args.precision != "fp32":
        model = model.to(torch.bfloat16 if args.precision == "bf16" else torch.float16)
    model = model.to(get_current_device())
    coordinator.print_on_master(f"Finished initializing model with config:\n{config}")

    # Replace the dense MoE layers with ColossalAI's expert-parallel MoE layers
    with skip_init():
        replace_moe_layer(model)
    model.eval()
    coordinator.print_on_master("Finished replacing MoE modules")

    # Prepare tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
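    # booster.boost wraps the model for distributed execution; the underlying
    # HuggingFace model is reached via `model.module` when calling generate below.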
    model, _, _, _, _ = booster.boost(model=model)
    coordinator.print_on_master(f"Finish init booster")

    # Load the pretrained checkpoint
    load_model(args.model_name, model, booster)
    coordinator.print_on_master("Finished loading checkpoint")

    # Run batched generation on a couple of sample prompts
    text = ["Hello my name is", "1+1=?"]
    tokenizer.pad_token = tokenizer.unk_token
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(get_current_device())
    outputs = model.module.generate(**inputs, max_new_tokens=20)
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    if coordinator.is_master():
        for output in outputs:
            print(output)


if __name__ == "__main__":
    main()