"vscode:/vscode.git/clone" did not exist on "5027c4f95b65dbf76c3f72ce0fcda89716fb5f62"
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import os
import json
from pydantic import Field, ValidationError
from typing import List
from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class SimpleConf(DeepSpeedConfigModel):
    '''Small config model used by the tests in this file.

    It declares a plain field, a deprecated field that names a replacement
    field, and a field that carries an alias.
    '''
    # Plain integer field with a default value.
    param_1: int = 0
    # Deprecated field: the extra Field arguments name "param_2" as its
    # replacement and give a function that wraps the old value in a list.
    # The tests in this file expect the wrapped value to show up on param_2.
    param_2_old: str = Field(None, deprecated=True, new_param="param_2", new_param_fn=(lambda x: [x]))
    # Replacement for param_2_old.
    param_2: List[str] = None
    # Field declared with the alias "param_3_alias".
    param_3: int = Field(0, alias="param_3_alias")


def test_only_required_fields(tmpdir):
    '''A config file holding nothing but train_batch_size must load cleanly.

    Besides the batch size itself, the test checks the two derived values:
    micro batch size per GPU equal to 64 and gradient accumulation steps
    equal to 1 (presumably because the test runs with a single process --
    confirm against DeepSpeedConfig).
    '''
    minimal_settings = {'train_batch_size': 64}
    settings_file = tmpdir.mkdir('ds_config_unit_test').join('minimal.json')

    # Serialize first, then write the text in one call.
    serialized = json.dumps(minimal_settings)
    with open(settings_file, 'w') as handle:
        handle.write(serialized)

    loaded = ds_config.DeepSpeedConfig(settings_file)
    assert loaded is not None
    assert loaded.train_batch_size == 64
    assert loaded.train_micro_batch_size_per_gpu == 64
    assert loaded.gradient_accumulation_steps == 1


def test_config_duplicate_key(tmpdir):
    '''Ensure that a config file repeating a key is rejected with ValueError. '''
    # Apart from the repeated key the document is well-formed JSON. A trailing
    # comma after the last pair would make a strict JSON parser raise a
    # ValueError subclass on its own, letting this test pass even when the
    # duplicate key is never detected.
    config_dict = '''
    {
        "train_batch_size": 24,
        "train_batch_size": 24
    }
    '''
    config_path = os.path.join(tmpdir, 'temp_config.json')

    with open(config_path, 'w') as jf:
        jf.write(config_dict)

    # The constructed object is irrelevant; only the raised error matters.
    with pytest.raises(ValueError):
        ds_config.DeepSpeedConfig(config_path)


def test_config_base():
    '''A value passed for param_1 is stored on the model unchanged.'''
    expected = 42
    model = SimpleConf(param_1=expected)
    assert model.param_1 == expected


def test_config_base_deprecatedfield():
    '''Setting only the deprecated param_2_old must populate param_2.

    The expected value is the old string wrapped in a list, matching the
    new_param_fn declared on the field.
    '''
    legacy_kwargs = {"param_2_old": "DS"}
    migrated = SimpleConf(**legacy_kwargs)
    assert migrated.param_2 == ["DS"]


def test_config_base_aliasfield():
    '''param_3 can be set through its own name or through its alias.'''
    # Same check for both spellings, field name first and alias second.
    for key in ("param_3", "param_3_alias"):
        config = SimpleConf(**{key: 10})
        assert config.param_3 == 10


@pytest.mark.parametrize("config_dict", [{"param_1": "DS"}, {"param_2": "DS"}, {"param_1_typo": 0}])
def test_config_base_literalfail(config_dict):
    '''Each parametrized dict must be rejected with a ValidationError.

    The cases are a non-integer string for param_1, a bare string for the
    list-typed param_2, and a key that is not a declared field.
    '''
    # The constructed object is irrelevant; only the raised error matters.
    with pytest.raises(ValidationError):
        SimpleConf(**config_dict)


def test_config_base_deprecatedfail():
    '''Passing param_2 together with param_2_old must raise AssertionError.'''
    conflicting_kwargs = {"param_2": ["DS"], "param_2_old": "DS"}
    with pytest.raises(AssertionError):
        SimpleConf(**conflicting_kwargs)