# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import time
import torch
import pytest
import itertools
import deepspeed
from deepspeed.git_version_info import torch_info
from unit.common import DistributedTest
from packaging import version as pkg_version
from deepspeed.ops.op_builder import OpBuilder
from transformers import pipeline
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.roberta.modeling_roberta import RobertaLayer
from huggingface_hub import HfApi
from deepspeed.model_implementations import DeepSpeedTransformerInference
from torch import nn
from deepspeed.accelerator import get_accelerator

rocm_version = OpBuilder.installed_rocm_version()
if rocm_version != (0, 0):
    pytest.skip("skip inference tests on rocm for now", allow_module_level=True)

_bert_models = [
    "bert-base-cased",
    "bert-base-uncased",
    "bert-large-cased",
    "bert-large-uncased",
    "bert-base-multilingual-cased",
    "bert-base-multilingual-uncased",
    "deepset/minilm-uncased-squad2",
    "cross-encoder/ms-marco-MiniLM-L-12-v2",
    "dslim/bert-base-NER",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "distilbert-base-cased-distilled-squad",
]
_roberta_models = [
    "roberta-large",
    "roberta-base",
    "deepset/roberta-base-squad2",
    "j-hartmann/emotion-english-distilroberta-base",
    "Jean-Baptiste/roberta-large-ner-english",
]
_gpt_models = [
    "gpt2",
    "distilgpt2",
    "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
    #"EleutherAI/gpt-j-6B", # Removed as this is causing OOM errors randomly
    "bigscience/bloom-560m",
]
_opt_models = [
    "facebook/opt-125m",  # 125m, 1.7B, ..., 175B variants have the same model architecture.
    "facebook/opt-350m",  # 350m applies layer norm after attnention layer which is different than other variants.
]
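# Query the Hugging Face Hub once; the result is used below to filter model/task combinations.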
_all_models = HfApi().list_models()

test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
test_tasks = [
    "fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
    "text2text-generation", "summarization", "translation"
]
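# For each task, collect the Hub model IDs that advertise that pipeline task.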
pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks}

_model_w_tasks = itertools.product(test_models, test_tasks)


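# Keep only (model, task) pairs where the model is actually listed on the Hub for that task.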
def _valid_model_task(model_task):
    m, t = model_task
    return m in pytest.all_models[t]


pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
"""
These fixtures iterate over all combinations of model/task, dtype, and cuda_graph
"""


@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
    return request.param


@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
    return request.param


@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
    return request.param


"""
This fixture validates each model/task/dtype/cuda_graph configuration and returns a skip message for invalid combinations
"""


@pytest.fixture()
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
    model, task = model_w_task
    msg = ""
    if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
        msg = "DS inference injection doesn't work well on older torch versions"
    elif model not in pytest.all_models[task]:
        msg = f"Not a valid model / task combination: {model} / {task}"
    elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
        msg = "CUDA not detected, cannot use CUDA Graph"
    elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
        msg = "CUDA Graph is only available in torch versions >= 1.10"
    elif "gpt-j-6B" in model:
        if dtype != torch.half:
            msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
        elif enable_cuda_graph:
            msg = f"Not enough GPU memory to run {model} with CUDA Graph enabled"
    elif "gpt-neox-20b" in model:  # TODO: remove this when neox issues resolved
        msg = "Skipping gpt-neox-20b for now"
    elif ("gpt-neox-20b" in model) and (dtype != torch.half):
        msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
    elif ("bloom" in model) and (dtype != torch.half):
        msg = f"Bloom models only support half precision, cannot use dtype {dtype}"
    elif ("bert" not in model.lower()) and enable_cuda_graph:
        msg = "Non bert/roberta models do no support CUDA Graph"
    return msg


"""
These fixtures can be used to customize the query, inference args, and assert
statement for each combination of model/task
"""


@pytest.fixture
def query(model_w_task):
    model, task = model_w_task
    angle_bracket_mask_models = ["roberta", "camembert", "esm", "ibert", "luke", "mpnet", "yoso"]

    if task == "fill-mask":
        if any(map(lambda x: x in model, angle_bracket_mask_models)):
            return "Hello I'm a <mask> model."
        else:
            return "Hell I'm a [MASK] model."
    elif task == "question-answering":
        return {
            "question": "What's my name?",
            "context": "My name is Clara and I live in Berkeley",
        }
    elif task == "text-classification":
        return "DeepSpeed is the greatest"
    elif task == "token-classification":
        return "My name is jean-baptiste and I live in montreal."
    elif task == "text-generation":
        return "DeepSpeed is the greatest"
    elif task == "text2text-generation":
        return "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
    elif task == "translation" or task == "summarization":
        return "Hello, my dog is cute"
    else:
        raise NotImplementedError(f'query for task "{task}" is not implemented')


@pytest.fixture
def inf_kwargs(model_w_task):
    model, task = model_w_task
    if task == "text-generation":
        if model == "EleutherAI/gpt-j-6B":
            # On V100 this model hits memory problems that limit the number of output tokens
            return {"do_sample": False, "max_length": 12}
        return {"do_sample": False, "max_length": 20}
    else:
        return {}


def fill_mask_assert(x, y):
    return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)


def question_answering_assert(x, y):
    return x["answer"] == y["answer"]


def text_classification_assert(x, y):
    return set(res["label"] for res in x) == set(res["label"] for res in y)


def token_classification_assert(x, y):
    return set(ent["word"] for ent in x) == set(ent["word"] for ent in y)


def text_generation_assert(x, y):
    return set(res["generated_text"] for res in x) == set(res["generated_text"] for res in y)


def text2text_generation_assert(x, y):
    return set(res["generated_text"] for res in x) == set(res["generated_text"] for res in y)


def translation_assert(x, y):
    return set(res["translation_text"] for res in x) == set(res["translation_text"] for res in y)


def summarization_assert(x, y):
    return set(res["summary_text"] for res in x) == set(res["summary_text"] for res in y)


@pytest.fixture
def assert_fn(model_w_task):
    model, task = model_w_task
    assert_fn_dict = {
        "fill-mask": fill_mask_assert,
        "question-answering": question_answering_assert,
        "text-classification": text_classification_assert,
        "token-classification": token_classification_assert,
        "text-generation": text_generation_assert,
        "text2text-generation": text2text_generation_assert,
        "translation": translation_assert,
        "summarization": summarization_assert
    }
    assert_fn = assert_fn_dict.get(task, None)
    if assert_fn is None:
        raise NotImplementedError(f'assert_fn for task "{task}" is not implemented')
    return assert_fn


def check_injection(model):
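    # Recursively check that transformer layers (the first element of any ModuleList)
    # were replaced by DeepSpeed's kernel-injected DeepSpeedTransformerInference module.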

    def verify_injection(module):
        for child in module.children():
            if isinstance(child, nn.ModuleList):
                assert isinstance(child[0], DeepSpeedTransformerInference),\
                    "DeepSpeed-Inference Transformer kernels has not been injected in the model"
                break
            else:
                verify_injection(child)

    verify_injection(model)


"""
Tests
"""


@pytest.mark.inference
class TestModelTask(DistributedTest):
    world_size = 1

    def test(
        self,
        model_w_task,
        dtype,
        enable_cuda_graph,
        query,
        inf_kwargs,
        assert_fn,
        invalid_model_task_config,
    ):
        if invalid_model_task_config:
            pytest.skip(invalid_model_task_config)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        # Load the model on CPU first to avoid OOM for large models at fp32
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        if dtype == torch.half:
            pipe.model.half()

        # Switch device to GPU after converting to half
        device = torch.device(get_accelerator().device_name(local_rank))
        pipe.device = device
        pipe.model.to(device)

        # Warm-up queries for perf measurement
        #for i in range(10):
        #    _ = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        start = time.time()
        bs_output = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        bs_time = time.time() - start

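        # Wrap the HF model with DeepSpeed-Inference kernel injection and run the same query again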
        pipe.model = deepspeed.init_inference(
            pipe.model,
            mp_size=1,
            dtype=dtype,
            replace_with_kernel_inject=True,
            enable_cuda_graph=enable_cuda_graph,
        )
        check_injection(pipe.model)
        # Warm-up queries for perf measurement
        #for i in range(10):
        #    _ = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        start = time.time()
        ds_output = pipe(query, **inf_kwargs)
        get_accelerator().synchronize()
        ds_time = time.time() - start

        # facebook/opt* and some bigscience/bloom* models do not match the
        # baseline exactly; add an exception for them for now
        if ("opt" in model) or ("bloom" in model):
            bs_output = pipe(query, **inf_kwargs)

        # These performance tests only measure the time for a single inference
        # request; we just want to check that performance isn't terrible
        #assert ds_time <= (bs_time * 1.1)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
                                          ("EleutherAI/gpt-neox-20b", "text-generation"),
                                          ("bigscience/bloom-3b", "text-generation"),
                                          ("EleutherAI/gpt-j-6B", "text-generation")],
                         ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
class TestMPSize(DistributedTest):
    world_size = 4

    def test(
        self,
        model_w_task,
        dtype,
        query,
        inf_kwargs,
        assert_fn,
        invalid_model_task_config,
    ):
        if invalid_model_task_config:
            pytest.skip(invalid_model_task_config)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        # We have to load these large models on CPU with pipeline because there is
        # not enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

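        # Shard the model across the ranks (mp_size == world_size) using kernel injection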
        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=self.world_size,
                                              dtype=dtype,
                                              replace_with_kernel_inject=True)
        check_injection(pipe.model)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
    "model_w_task, injection_policy",
    [
        (("google/t5-v1_1-small", "text2text-generation"), {
            T5Block: ('SelfAttention.o', 'EncDecAttention.o', 'DenseReluDense.wo')
        }),
        (("roberta-large", "fill-mask"), {
            RobertaLayer: ('output.dense')
        }),
    ],
    ids=["t5", "roberta"],
)
@pytest.mark.parametrize("dtype", [torch.float], ids=["fp32"])
@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"])
class TestInjectionPolicy(DistributedTest):
    world_size = [1, 2]

    def test(
        self,
        model_w_task,
        injection_policy,
        query,
        inf_kwargs,
        assert_fn,
        invalid_model_task_config,
        dtype,
        enable_cuda_graph,
    ):
        if invalid_model_task_config:
            pytest.skip(invalid_model_task_config)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        # We have to load these large models on CPU with pipeline because there is
        # not enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

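        # injection_policy maps each transformer block class to the linear layers whose
        # outputs must be all-reduced when the block is split across model-parallel ranks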
        pipe.model = deepspeed.init_inference(pipe.model,
                                              mp_size=world_size,
                                              dtype=dtype,
                                              injection_policy=injection_policy)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.seq_inference
@pytest.mark.parametrize(
    "model_w_task",
    [
        ("Helsinki-NLP/opus-mt-en-de", "translation"),
    ],
    ids=[
        "marian",
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16], ids=["fp16"])
@pytest.mark.parametrize("enable_cuda_graph", [False], ids=["noCG"])
class TestAutoTensorParallelism(DistributedTest):
    world_size = [2]

    def test(
        self,
        model_w_task,
        query,
        inf_kwargs,
        assert_fn,
        invalid_model_task_config,
        dtype,
        enable_cuda_graph,
    ):
        if invalid_model_task_config:
            pytest.skip(invalid_model_task_config)

        model, task = model_w_task
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "2"))

        # We have to load these large models on CPU with pipeline because there is
        # not enough GPU memory
        pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
        bs_output = pipe(query, **inf_kwargs)

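        # With no kernel injection and no injection policy, DeepSpeed applies its
        # automatic tensor-parallelism path for supported model architectures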
        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
        # Switch device to GPU so that input tensors are not on CPU
        pipe.device = torch.device(get_accelerator().device_name(local_rank))
        ds_output = pipe(query, **inf_kwargs)

        print(local_rank, "baseline", bs_output)
        print(local_rank, "deepspeed", ds_output)
        assert assert_fn(bs_output, ds_output)


@pytest.mark.nightly
@pytest.mark.parametrize(
    "model_family, model_name",
    (
        ["gpt2", "EleutherAI/gpt-neo-2.7B"],
        ["gpt2", "EleutherAI/gpt-j-6B"],
        ["gpt2", "gpt2-xl"],
    ),
)
@pytest.mark.parametrize("task", ["lambada_standard"])
class TestLMCorrectness(DistributedTest):
    world_size = 1

    def test(self, model_family, model_name, task):
        # imports here to avoid import errors when pytest collects tests
        import lm_eval
        import lm_eval.models
        import lm_eval.tasks
        import lm_eval.evaluator

        local_rank = os.getenv("LOCAL_RANK", "0")
        device = torch.device(get_accelerator().device_name(local_rank))
        dtype = torch.float
        task_dict = lm_eval.tasks.get_task_dict([task])

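        # gpt-j-6B is too large to evaluate at fp32 on a single GPU, so build it on CPU,
        # cast it to half precision, and then move it to the accelerator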
        if 'gpt-j-6B' in model_name:
            dtype = torch.half
            lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
                                                                               {"device": "cpu"})
            setattr(lm, model_family, getattr(lm, model_family).half().to(device))
            lm._device = device
        else:
            lm = lm_eval.models.get_model(model_family).create_from_arg_string(
                f"pretrained={model_name}", {"device": get_accelerator().device_name()})

        get_accelerator().synchronize()
        start = time.time()
        bs_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
        get_accelerator().synchronize()
        bs_time = time.time() - start

        ds_model = deepspeed.init_inference(
            getattr(lm, model_family),
            mp_size=1,
            dtype=dtype,
            replace_with_kernel_inject=True,
            enable_cuda_graph=False,
        )
        check_injection(ds_model)
        setattr(lm, model_family, ds_model)
        get_accelerator().synchronize()
        start = time.time()
        ds_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
        get_accelerator().synchronize()
        ds_time = time.time() - start

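        # The DeepSpeed-injected model should reproduce the baseline lambada perplexity
        # within a small tolerance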
        ppl_diff = abs(bs_output["results"][task]["ppl"] - ds_output["results"][task]["ppl"])
        #assert ds_time <= bs_time
        assert ppl_diff < 0.01