"vscode:/vscode.git/clone" did not exist on "4067d6c4b64f2b606f9806d4a8b15d5fd5cbea1e"
test_torchscript.py 2.19 KB
Newer Older
Owen Wang's avatar
Owen Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import unittest

from d2go.export.torchscript import (
    MobileOptimizationConfig,
    update_export_kwargs_from_export_method,
)


@update_export_kwargs_from_export_method
def mock_export(cls, model, input_args, save_path, export_method, **export_kwargs):
    """Stand-in exporter that echoes back the keyword arguments it receives.

    The decorator rewrites ``export_kwargs`` based on ``export_method`` before
    this body runs. Returning the resulting dict lets tests inspect exactly
    what the decorator produced. ``cls``, ``model``, ``input_args`` and
    ``save_path`` are accepted only to match the decorated signature and are
    otherwise ignored.
    """
    received_kwargs = export_kwargs
    return received_kwargs


class TestTorchscriptExportMethods(unittest.TestCase):
    def test_update_export_kwargs_from_export_method(self):
        """Check the kwargs derived from each supported export-method string.

        ``mock_export`` returns the kwargs it receives after the
        ``update_export_kwargs_from_export_method`` decorator has processed
        ``export_method``, so every case below inspects that returned dict.
        """

        def try_mock_export(export_method: str, export_kwargs=None):
            # ``None`` rather than a shared ``{}`` default avoids the
            # mutable-default-argument pitfall.
            return mock_export(
                cls=None,
                model=None,
                input_args=None,
                save_path=None,
                export_method=export_method,
                **(export_kwargs or {}),
            )

        # Plain torchscript export must not request any mobile optimization.
        new_export_kwargs = try_mock_export("torchscript")
        self.assertNotIn("mobile_optimization", new_export_kwargs)

        # The "_mobile" variants attach a MobileOptimizationConfig. The
        # optional "-<backend>" suffix selects the backend; the bare form
        # yields "CPU".
        for export_method, expected_backend in (
            ("torchscript_mobile", "CPU"),
            ("torchscript_mobile-metal", "metal"),
            ("torchscript_mobile-vulkan", "vulkan"),
        ):
            with self.subTest(export_method=export_method):
                new_export_kwargs = try_mock_export(export_method)
                self.assertIn("mobile_optimization", new_export_kwargs)
                mobile_optimization = new_export_kwargs["mobile_optimization"]
                self.assertIsInstance(mobile_optimization, MobileOptimizationConfig)
                self.assertEqual(mobile_optimization.backend, expected_backend)

        # The "@tracing" / "@scripting" suffixes select the JIT mode.
        for export_method, expected_jit_mode in (
            ("torchscript_mobile@tracing", "trace"),
            ("torchscript_mobile@scripting", "script"),
        ):
            with self.subTest(export_method=export_method):
                new_export_kwargs = try_mock_export(export_method)
                self.assertIn("jit_mode", new_export_kwargs)
                self.assertEqual(new_export_kwargs["jit_mode"], expected_jit_mode)