# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import pytest
import torch
import deepspeed
from deepspeed.model_implementations import DeepSpeedTransformerInference
from unit.common import DistributedTest, DistributedFixture
from transformers import AutoConfig, AutoModelForCausalLM


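# Walk the module tree to find the first kernel-injected
# DeepSpeedTransformerInference layer and assert that its attention weights
# carry the expected dtype.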
def check_dtype(model, expected_dtype):

    def find_dtype(module):
        for child in module.children():
            if isinstance(child, DeepSpeedTransformerInference):
                return child.attention.attn_qkvw.dtype
            else:
                found_dtype = find_dtype(child)
                if found_dtype:
                    return found_dtype

    found_dtype = find_dtype(model)
    assert found_dtype, "Did not find DeepSpeedTransformerInference in model"
    assert (found_dtype == expected_dtype), f"Expected transformer dtype {expected_dtype}, but found {found_dtype}"


@pytest.fixture(
    params=["bigscience/bloom-560m", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-125M", "facebook/opt-125m"])
def model_name(request):
    return request.param


@pytest.fixture(params=[torch.float16, torch.int8], ids=["fp16", "int8"])
def dtype(request):
    return request.param


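# DistributedFixture: runs `run()` in its own distributed job (world_size
# ranks) before any test that requests it, so the sharded checkpoint is on
# disk before TestCheckpointShard executes.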
class save_shard(DistributedFixture):
    world_size = 2

    def run(self, model_name, class_tmpdir):
        # Only write a checkpoint if one does not exist
        if not os.path.isdir(os.path.join(class_tmpdir, model_name)):
            world_size = int(os.getenv("WORLD_SIZE", "1"))
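            # fp16 kernel injection with tensor parallelism across all ranks;
            # "save_mp_checkpoint_path" asks DeepSpeed to write the resulting
            # TP shards (and ds_inference_config.json) under class_tmpdir.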
            inf_config = {
                "replace_with_kernel_inject": True,
                "dtype": torch.float16,
                "enable_cuda_graph": False,
                "tensor_parallel": {
                    "tp_size": world_size
                },
                "save_mp_checkpoint_path": os.path.join(str(class_tmpdir), model_name),
            }

            # Load model and save sharded checkpoint
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
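            # init_inference both injects the kernels and, because
            # "save_mp_checkpoint_path" is set, persists the sharded checkpoint
            # as a side effect; no forward pass is required.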
            model = deepspeed.init_inference(model, config=inf_config)


@pytest.mark.seq_inference
class TestCheckpointShard(DistributedTest):
    world_size = 2

    def test(self, model_name, dtype, class_tmpdir, save_shard):
        world_size = int(os.getenv("WORLD_SIZE", "1"))
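        # Same kernel-injection/TP config as save_shard, but "checkpoint" now
        # points at the inference config JSON written alongside the shards.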
        inf_config = {
            "replace_with_kernel_inject": True,
            "dtype": dtype,
            "enable_cuda_graph": False,
            "tensor_parallel": {
                "tp_size": world_size
            },
            "checkpoint": os.path.join(class_tmpdir, model_name, "ds_inference_config.json"),
        }

        # Load model on meta tensors
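        # Meta tensors hold shape/dtype but no storage, so no real weights are
        # allocated here; they are materialized from the shards in
        # init_inference below.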
        model_config = AutoConfig.from_pretrained(model_name)
        # Note that we use half precision to load initially, even for int8
        with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
            model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.float16)
        model = model.eval()
        model = deepspeed.init_inference(model, config=inf_config)
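        # For the int8 case init_inference quantizes the injected weights, so
        # the dtype check below should see the requested dtype either way.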
        check_dtype(model, dtype)
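

# A hypothetical invocation sketch (assumes the DeepSpeed unit-test layout,
# pytest, and at least 2 GPUs; DistributedTest spawns the ranks itself):
#
#   cd tests && pytest -m seq_inference unit/inference/test_checkpoint_sharding.py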