main.py 7.5 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import subprocess
import mlperf_loadgen as lg
import argparse
import os
import math
import sys
from backend_PyTorch import get_SUT
from GPTJ_QDL import GPTJ_QDL
from GPTJ_QSL import get_GPTJ_QSL

# Function to parse the arguments passed during python file execution


def get_args(argv=None):
    """Parse the command-line options for a benchmark run.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case ``sys.argv[1:]`` is used — byte-for-byte
            backward compatible with the previous zero-argument call, but
            also makes the parser testable and reusable programmatically.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--backend", choices=["pytorch"], default="pytorch", help="Backend"
    )
    parser.add_argument(
        "--scenario",
        choices=["SingleStream", "Offline", "Server"],
        default="Offline",
        help="Scenario",
    )
    parser.add_argument("--model-path", default="EleutherAI/gpt-j-6B", help="")
    parser.add_argument(
        "--dataset-path",
        default="./data/cnn_eval.json",
        help="")
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="enable accuracy pass")
    parser.add_argument(
        "--dtype",
        default="float32",
        help="data type of the model, choose from float16, bfloat16 and float32",
    )
    parser.add_argument(
        "--quantized",
        action="store_true",
        help="use quantized model (only valid for onnxruntime backend)",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="enable profiling (only valid for onnxruntime backend)",
    )
    parser.add_argument(
        "--gpu", action="store_true", help="use GPU instead of CPU for the inference"
    )
    parser.add_argument(
        "--audit_conf",
        default="audit.conf",
        help="audit config for LoadGen settings during compliance runs",
    )
    parser.add_argument(
        "--user_conf",
        default="user.conf",
        help="user config for user LoadGen settings such as target QPS",
    )
    parser.add_argument(
        "--max_examples",
        type=int,
        default=13368,
        help="Maximum number of examples to consider (not limited by default)",
    )
    parser.add_argument(
        "--network",
        choices=["sut", "lon", None],
        default=None,
        help="Loadgen network mode",
    )
    parser.add_argument("--node", type=str, default="")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--sut_server",
        nargs="*",
        default=["http://localhost:8000"],
        help="Address of the server(s) under test.",
    )
    # parse_args(None) reads sys.argv[1:], preserving the old behaviour.
    args = parser.parse_args(argv)
    return args


# Function to get the amount of temporary cache generated when running the GPT-J model
# Varies with the beam size set (Estimate: 6 GB x beam size)


def get_temp_cache():
    """Return the estimated per-instance temporary cache size in GB.

    Scales linearly with the beam size taken from the GPTJ_BEAM_SIZE
    environment variable (default "4"), at roughly 6 GB per beam.
    """
    return 6 * int(os.environ.get("GPTJ_BEAM_SIZE", "4"))


# Map the user-facing scenario name to the corresponding LoadGen enum value.
scenario_map = {
    name: getattr(lg.TestScenario, name)
    for name in ("SingleStream", "Offline", "Server", "MultiStream")
}

# Main function triggered when the script is run


def main():
    """Entry point: configure LoadGen and run the benchmark.

    Behaviour is selected by --network:
      * None  -- load QSL and SUT locally and run the LoadGen test in-process.
      * "lon" -- LoadGen-Over-Network client: loads the QSL plus a QDL that
                 forwards queries to the remote SUT server(s).
      * "sut" -- server side: loads the model and serves inference requests
                 over HTTP via a Flask app (no LoadGen run here).
    """
    args = get_args()
    qsl = None
    if args.network != "sut":
        # Gets the Query Data Loader and Query Sample Loader.
        # Responsible for loading (Query Sample Loader) and sending samples
        # over the network (Query Data Loader) to the server.
        qsl = get_GPTJ_QSL(
            dataset_path=args.dataset_path, max_examples=args.max_examples
        )
        if args.network == "lon":
            qdl = GPTJ_QDL(
                sut_server_addr=args.sut_server, scenario=args.scenario, qsl=qsl
            )

        # Initiates and loads loadgen test settings and log path
        settings = lg.TestSettings()
        settings.scenario = scenario_map[args.scenario]
        # mlperf conf is automatically loaded by the loadgen
        # settings.FromConfig(args.mlperf_conf, "gptj", args.scenario)
        settings.FromConfig(args.user_conf, "gptj", args.scenario)

        # Choosing test mode: Accuracy / Performance
        if args.accuracy:
            settings.mode = lg.TestMode.AccuracyOnly
        else:
            settings.mode = lg.TestMode.PerformanceOnly

        # Set log path; fall back to build/logs when LOG_PATH is unset/empty.
        # exist_ok avoids the check-then-create race of the previous version.
        log_path = os.environ.get("LOG_PATH") or "build/logs"
        os.makedirs(log_path, exist_ok=True)

        log_output_settings = lg.LogOutputSettings()
        log_output_settings.outdir = log_path
        log_output_settings.copy_summary_to_stdout = True
        log_settings = lg.LogSettings()
        log_settings.log_output = log_output_settings
        log_settings.enable_trace = True

    if args.network != "lon":
        # Gets SUT.
        # SUT is only initialised when network is None or "sut"; the LON
        # client never loads the model.
        sut = get_SUT(
            model_path=args.model_path,
            scenario=args.scenario,
            dtype=args.dtype,
            use_gpu=args.gpu,
            network=args.network,
            dataset_path=args.dataset_path,
            max_examples=args.max_examples,
            qsl=qsl,  # If args.network is None, then only QSL get passed to the SUT, else it will be None
        )

    if args.network == "lon" and args.scenario == "SingleStream":
        print(
            "ERROR: Single stream scenario in Loadgen Over the Network is not supported!"
        )
        # BUGFIX: abort here. Previously execution fell through to
        # "Test Done!" and QSL teardown even though no test was run.
        sys.exit(1)

    # If network option is LON, QDL is loaded and request is served to SUT
    # based on the scenario given by user (Offline or SingleStream)
    elif args.network == "lon":
        lg.StartTestWithLogSettings(
            qdl.qdl, qsl.qsl, settings, log_settings, args.audit_conf
        )

    # If network option is SUT, a flask server is initiated, request is
    # processed and output is being sent back to LON client
    elif args.network == "sut":
        temp_cache = get_temp_cache()
        from network_SUT import app, node, set_backend, set_semaphore
        from systemStatus import get_cpu_memory_info, get_gpu_memory_info

        # Calculating free memory in order to determine the number of
        # instances that can run at a time. Formula:
        #   instances = (free_memory - model_memory_size) / temp_cache_size
        # A semaphore initialised to that count gates concurrent requests;
        # acquire/release of the semaphore can be found in network_SUT.py.
        model_mem_size = sut.total_mem_size
        if args.gpu:
            # CM_CUDA_DEVICE_PROP_GLOBAL_MEMORY (bytes) takes precedence over
            # the probed GPU value when set; converted to GB.
            free_mem = int(
                os.environ.get(
                    "CM_CUDA_DEVICE_PROP_GLOBAL_MEMORY", get_gpu_memory_info()
                )
            ) / (1024**3)
        else:
            free_mem = get_cpu_memory_info()
        lockVar = math.floor((free_mem - model_mem_size) / temp_cache)
        node = args.node
        # Semaphore is set in order to create multiple instances upon
        # incoming requests
        set_semaphore(lockVar)
        print(f"Set the semaphore lock variable to {lockVar}")
        # Pass SUT as the backend
        set_backend(sut)
        app.run(debug=False, port=args.port, host="0.0.0.0")

    else:
        # network is None: run the LoadGen test locally, not over the network
        print("Running LoadGen test...")
        lg.StartTestWithLogSettings(
            sut.sut, qsl.qsl, settings, log_settings, args.audit_conf
        )

    print("Test Done!")

    if args.network != "lon":
        print("Destroying SUT...")
        lg.DestroySUT(sut.sut)

    if args.network != "sut":
        print("Destroying QSL...")
        lg.DestroyQSL(qsl.qsl)


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()