# test_torchscript.py
import unittest

from d2go.export.torchscript import (
    MobileOptimizationConfig,
    update_export_kwargs_from_export_method,
)


@update_export_kwargs_from_export_method
def mock_export(
    cls,
    model,
    input_args,
    save_path,
    export_method,
    **export_kwargs,
):
    """Fake export entry point that echoes back the kwargs it receives.

    ``update_export_kwargs_from_export_method`` may add or rewrite entries in
    ``export_kwargs`` based on ``export_method`` before this body runs.
    Returning the kwargs unchanged lets the tests inspect exactly what the
    decorator produced. The other parameters are ignored; presumably they
    exist only to mirror the real export signature the decorator expects.
    """
    echoed_kwargs = export_kwargs
    return echoed_kwargs


class TestTorchscriptExportMethods(unittest.TestCase):
    """Tests for how export-method strings are turned into export kwargs."""

    def test_update_export_kwargs_from_export_method(self):
        """Check the kwargs produced for each supported export-method spelling.

        ``mock_export`` is wrapped by ``update_export_kwargs_from_export_method``
        and returns the kwargs it receives. Each case below therefore asserts
        on what the decorator derived from the export-method string, starting
        from an otherwise empty kwargs dict.
        """

        def try_mock_export(export_method: str, export_kwargs=None):
            # Use a None sentinel instead of a shared dict default, so no call
            # can ever observe kwargs left behind by another call.
            if export_kwargs is None:
                export_kwargs = {}
            return mock_export(
                cls=None,
                model=None,
                input_args=None,
                save_path=None,
                export_method=export_method,
                **export_kwargs,
            )

        # Plain "torchscript" must not request any mobile optimization.
        new_export_kwargs = try_mock_export("torchscript")
        self.assertNotIn("mobile_optimization", new_export_kwargs)

        # "torchscript_mobile" without a suffix yields a config whose backend is "CPU".
        new_export_kwargs = try_mock_export("torchscript_mobile")
        self.assertIn("mobile_optimization", new_export_kwargs)
        self.assertIsInstance(
            new_export_kwargs["mobile_optimization"], MobileOptimizationConfig
        )
        self.assertEqual(new_export_kwargs["mobile_optimization"].backend, "CPU")

        # A "-<backend>" suffix selects the mobile optimization backend.
        backend_cases = (
            ("torchscript_mobile-metal", "metal"),
            ("torchscript_mobile-vulkan", "vulkan"),
        )
        for export_method, expected_backend in backend_cases:
            with self.subTest(export_method=export_method):
                new_export_kwargs = try_mock_export(export_method)
                self.assertEqual(
                    new_export_kwargs["mobile_optimization"].backend,
                    expected_backend,
                )

        # An "@<mode>" suffix selects the JIT mode ("tracing" -> "trace", etc.).
        jit_mode_cases = (
            ("torchscript_mobile@tracing", "trace"),
            ("torchscript_mobile@scripting", "script"),
        )
        for export_method, expected_jit_mode in jit_mode_cases:
            with self.subTest(export_method=export_method):
                new_export_kwargs = try_mock_export(export_method)
                self.assertEqual(new_export_kwargs["jit_mode"], expected_jit_mode)