moe_tune_runner.py 3.15 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import os
from pathlib import Path
import argparse

import aiter
from aiter.jit.core import AITER_ROOT_DIR
import pandas as pd
from moe_tuner import MoeTuner
from moe_problem import MoeQuantType, get_dtype, get_QuantType
from aiter.fused_moe_ck import ck_tuned_file



def _build_arg_parser():
    """Build the command-line parser for the MoE tuning runner.

    Defaults for ``--tuned_file`` and ``--input_file`` may be supplied via the
    ``MOE_TUNE_OUTPUT`` and ``MOE_TUNE_INPUT`` environment variables; ``--mp``
    defaults to the number of visible CUDA devices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tuned_file",
        type=str,
        default=os.getenv("MOE_TUNE_OUTPUT", ck_tuned_file),
        help="output file for tuned moe solutions",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default=os.getenv("MOE_TUNE_INPUT", None),
        help="CSV file listing the moe problems to tune for; when omitted, a"
        " single shape built from the command-line arguments is tuned",
    )
    parser.add_argument(
        "--mp",
        type=int,
        default=torch.cuda.device_count(),
        help="Tuning on multiple GPUs using multiple processes",
    )
    parser.add_argument(
        "--quant_type",
        type=str,
        default=MoeQuantType.NO_QUANT,
        help="quantization type: no_quant int4_w4a16 int4_w4a8 int8_w8a8_block...",
    )
    parser.add_argument(
        "--indtype",
        type=str,
        default="f16",
        choices=["f16", "bf16", "int8"],
        help="dtype: f16 bf16 int8. Passed to the tuner and used for the single"
        " shape tuned when no input_file is provided; rows of input_file use"
        " their own indtype column",
    )
    parser.add_argument(
        "--tokens",
        type=int,
        default=16,
        help="Number of tokens to process",
    )
    parser.add_argument(
        "--inter_size",
        type=int,
        default=256,
        help="Inter size of the MLP",
    )
    parser.add_argument(
        "--hidden_size",
        type=int,
        default=7168,
        help="Hidden size of the MLP",
    )
    parser.add_argument(
        "--experts",
        type=int,
        default=256,
        help="Number of experts",
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        help="TopK experts to use",
    )
    return parser


def _add_problems_from_csv(moe_tuner, input_file):
    """Register every row of the CSV ``input_file`` as a problem on ``moe_tuner``.

    Each row must provide the columns quant_type, indtype, token, inter_dim,
    model_dim, expert, topk, q_size_n and q_size_k (a missing column raises
    KeyError). Empty cells are replaced with "" before being forwarded —
    NOTE(review): presumably MoeTuner.add_moe treats "" as "unset" for the
    q_size_* columns; confirm against moe_tuner.py.
    """
    shapes = pd.read_csv(input_file).fillna("")
    for _, row in shapes.iterrows():
        moe_tuner.add_moe(
            get_QuantType(row["quant_type"]),
            get_dtype(row["indtype"]),
            row["token"],
            row["inter_dim"],
            row["model_dim"],
            row["expert"],
            row["topk"],
            row["q_size_n"],
            row["q_size_k"],
        )


def main():
    """Parse CLI arguments, register the MoE problems and run the tuner.

    Exits with status 1 if ``--input_file`` is given but is not a file.
    """
    import sys  # local: only needed for stderr output and the exit status

    args = _build_arg_parser().parse_args()

    # Validate the input path before building the tuner so a typo fails fast
    # instead of after (possibly multi-process) tuner setup.
    if args.input_file:
        print(f">>>Info: Loading {args.input_file}")
        if not Path(args.input_file).is_file():
            print(
                f">>> ERROR: {args.input_file} does not exist.  Exiting",
                file=sys.stderr,
            )
            sys.exit(1)

    # Both conversions run unconditionally so invalid CLI values are rejected
    # even when the problems come from the CSV file.
    indtype = get_dtype(args.indtype)
    quant_type = get_QuantType(args.quant_type)
    moe_tuner = MoeTuner(indtype, args.tuned_file, mp=args.mp)

    if args.input_file:
        _add_problems_from_csv(moe_tuner, args.input_file)
    else:
        print(">>>Info: No input_file provided, tuning a single shape")
        moe_tuner.add_moe(
            quant_type,
            indtype,
            args.tokens,
            args.inter_size,
            args.hidden_size,
            args.experts,
            args.topk,
        )

    moe_tuner.find_best_sols()


if __name__ == "__main__":
    main()