test_ds_config_model.py 2.52 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
'''Copyright The Microsoft DeepSpeed Team'''

import pytest
import os
import json
from pydantic import Field, ValidationError
from typing import List
from deepspeed.runtime import config as ds_config
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class SimpleConf(DeepSpeedConfigModel):
    '''Minimal config model used by the test_config_base_* tests below.

    It declares one plain field, one deprecated field with a named
    replacement, and one aliased field.
    '''
    # Plain integer field with a default of 0.
    param_1: int = 0
    # Deprecated field. The Field kwargs name "param_2" as its replacement and
    # supply new_param_fn to wrap the old scalar value in a list.
    # test_config_base_deprecatedfield expects param_2 == ["DS"] when only
    # this field is given. test_config_base_deprecatedfail expects an
    # AssertionError when both this field and param_2 are given.
    param_2_old: str = Field(None,
                             deprecated=True,
                             new_param="param_2",
                             new_param_fn=(lambda x: [x]))
    # Replacement for param_2_old. NOTE(review): annotated List[str] but
    # defaults to None; confirm this is intended rather than Optional.
    param_2: List[str] = None
    # test_config_base_aliasfield expects this field to be settable through
    # either "param_3" or "param_3_alias".
    param_3: int = Field(0, alias="param_3_alias")


def test_only_required_fields(tmpdir):
    '''Check that a config holding only the mandatory field is accepted.

    Only ``train_batch_size`` is written to the file. The micro batch size
    and gradient accumulation steps are expected to be derived from it.
    '''
    config_file = tmpdir.mkdir('ds_config_unit_test').join('minimal.json')
    minimal_fields = {'train_batch_size': 64}

    with open(config_file, 'w') as handle:
        json.dump(minimal_fields, handle)

    parsed = ds_config.DeepSpeedConfig(config_file)

    assert parsed is not None
    assert parsed.train_batch_size == 64
    # Values the config is expected to derive from train_batch_size alone.
    assert parsed.train_micro_batch_size_per_gpu == 64
    assert parsed.gradient_accumulation_steps == 1


def test_config_duplicate_key(tmpdir):
    '''Ensure a config file that repeats a key is rejected with ValueError.

    The payload must otherwise be valid JSON. ``json.JSONDecodeError`` is a
    subclass of ``ValueError``, so a malformed file (e.g. one with a trailing
    comma) would satisfy ``pytest.raises(ValueError)`` without ever reaching
    the duplicate-key check.
    '''
    config_dict = '''
    {
        "train_batch_size": 24,
        "train_batch_size": 24
    }
    '''
    # Sanity-check the fixture. The stdlib parser tolerates duplicate keys
    # (the last one wins), so this passing proves the text is well-formed
    # JSON. Any ValueError raised below is therefore not a JSON syntax error.
    assert json.loads(config_dict) == {"train_batch_size": 24}

    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as jf:
        jf.write(config_dict)

    with pytest.raises(ValueError):
        ds_config.DeepSpeedConfig(config_path)


def test_config_base():
    '''A plain integer field is populated from keyword arguments.'''
    expected = 42
    conf = SimpleConf(param_1=expected)
    assert conf.param_1 == expected


def test_config_base_deprecatedfield():
    '''Setting only the deprecated field fills its replacement via new_param_fn.'''
    conf = SimpleConf(param_2_old="DS")
    assert conf.param_2 == ["DS"]


def test_config_base_aliasfield():
    '''param_3 can be set through its own name or through its alias.'''
    for key in ("param_3", "param_3_alias"):
        conf = SimpleConf(**{key: 10})
        assert conf.param_3 == 10


@pytest.mark.parametrize("config_dict", [
    {"param_1": "DS"},  # non-integer string for the int field
    {"param_2": "DS"},  # bare string where List[str] is declared
    {"param_1_typo": 0},  # key not declared on SimpleConf
])
def test_config_base_literalfail(config_dict):
    '''Bad values or undeclared keys must raise a pydantic ValidationError.'''
    with pytest.raises(ValidationError):
        SimpleConf(**config_dict)


def test_config_base_deprecatedfail():
    '''Supplying both a deprecated field and its replacement is rejected.'''
    both = {"param_2": ["DS"], "param_2_old": "DS"}
    with pytest.raises(AssertionError):
        SimpleConf(**both)