test_utils_file_handler.py 2.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for file handler utilities."""

import unittest
from unittest import mock
from pathlib import Path

import yaml
from omegaconf import OmegaConf

from superbench.common.utils import get_sb_config


class FileHandlerUtilsTestCase(unittest.TestCase):
    """Test cases for ``get_sb_config`` from the file_handler utilities.

    The expected configs are read from paths relative to the current working
    directory, so these tests assume they are run from the repository root.
    """
    @staticmethod
    def _load_expected_config(relative_path):
        """Load a YAML config file and wrap it in an OmegaConf object.

        Args:
            relative_path (str): Path of the YAML file, relative to the current working directory.

        Returns:
            The OmegaConf object created from the file's parsed YAML contents.
        """
        # Resolve against cwd at call time, matching where the tests expect the repo root to be.
        with (Path.cwd() / relative_path).open() as fp:
            # safe_load is equivalent to yaml.load(fp, Loader=yaml.SafeLoader).
            return OmegaConf.create(yaml.safe_load(fp))

    @mock.patch('superbench.common.utils.azure.get_azure_imds')
    def test_get_sb_config_default(self, mock_get_azure_imds):
        """Test get_sb_config when no SKU is detected; the default config should be used.

        Args:
            mock_get_azure_imds (mock.MagicMock): Patched get_azure_imds, set to return an empty string.
        """
        mock_get_azure_imds.return_value = ''
        expected = self._load_expected_config('superbench/config/default.yaml')
        self.assertEqual(get_sb_config(None), expected)

    @mock.patch('superbench.common.utils.azure.get_azure_imds')
    def test_get_sb_config_sku(self, mock_get_azure_imds):
        """Test get_sb_config when a SKU with a matching config is detected; that config should be used.

        Args:
            mock_get_azure_imds (mock.MagicMock): Patched get_azure_imds, set to return a known SKU name.
        """
        mock_get_azure_imds.return_value = 'Standard_NC96ads_A100_v4'
        expected = self._load_expected_config('superbench/config/azure/inference/standard_nc96ads_a100_v4.yaml')
        self.assertEqual(get_sb_config(None), expected)

    @mock.patch('superbench.common.utils.azure.get_azure_imds')
    def test_get_sb_config_sku_nonexist(self, mock_get_azure_imds):
        """Test get_sb_config when a SKU without a config is detected; the default config should be used.

        Args:
            mock_get_azure_imds (mock.MagicMock): Patched get_azure_imds, set to return an unknown SKU name.
        """
        mock_get_azure_imds.return_value = 'Standard_Nonexist_A100_v4'
        expected = self._load_expected_config('superbench/config/default.yaml')
        self.assertEqual(get_sb_config(None), expected)