Commit 6b4dbb31 authored by Owen Wang's avatar Owen Wang Committed by Facebook GitHub Bot
Browse files

add metal GPU to d2go export

Summary: Allow string name of export type to indicate which mobile opt backend user wants to trigger.

Reviewed By: wat3rBro

Differential Revision: D35375928

fbshipit-source-id: dc3f91564681344e1d43862423ab5dc63b6644d3
parent 647a3fdf
......@@ -298,6 +298,30 @@ def update_export_kwargs_from_export_method(old_f):
model_export_kwargs of the PredictorExportConfig, user can simply put the `_mobile`
trigger word in the --predictor-type (which will then be forwarded as `export_method`
in most cases) to enable mobile optimization.
Please note that there's a finite set of allowed "export_method" values,
and an error will be raised if the string cannot be fully parsed.
The recognized values generally follow a pattern of:
"torchscript[_mobile][_int8][-vulkan | -metal][@scripting | @tracing]"
Some examples (not comprehensive because flag words' order can be swapped):
"torchscript"
"torchscript_mobile"
"torchscript_mobile-metal"
"torchscript_mobile-vulkan"
"torchscript_mobile_int8"
"torchscript@scripting"
"torchscript_int8@scripting"
"torchscript_mobile@scripting"
"torchscript_mobile-metal@scripting"
"torchscript_mobile-vulkan@scripting"
"torchscript_mobile_int8@scripting"
"torchscript@tracing"
"torchscript_int8@tracing"
"torchscript_mobile@tracing"
"torchscript_mobile-metal@tracing"
"torchscript_mobile-vulkan@tracing"
"torchscript_mobile_int8@tracing"
"""
def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
......@@ -311,7 +335,18 @@ def update_export_kwargs_from_export_method(old_f):
"`mobile_optimization` is already specified, keep using it"
)
else:
export_kwargs["mobile_optimization"] = MobileOptimizationConfig()
# Infer a MobileOptimizationConfig if none was provided
# "CPU" backend default. If found appropriate suffix, update the backend
if "-metal" in export_method:
    mobile_opt_config = MobileOptimizationConfig(backend="metal")
    # str.replace returns a new string; assign it back or the suffix stays
    export_method = export_method.replace("-metal", "", 1)
elif "-vulkan" in export_method:
    mobile_opt_config = MobileOptimizationConfig(backend="vulkan")
    export_method = export_method.replace("-vulkan", "", 1)
else:
mobile_opt_config = MobileOptimizationConfig()
export_kwargs["mobile_optimization"] = mobile_opt_config
export_method = export_method.replace("_mobile", "", 1)
if "@scripting" in export_method:
......@@ -375,14 +410,20 @@ class DefaultTorchscriptExport(ModelExportMethod):
@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile-metal")
@ModelExportMethodRegistry.register("torchscript_mobile-vulkan")
@ModelExportMethodRegistry.register("torchscript_mobile_int8")
@ModelExportMethodRegistry.register("torchscript@scripting")
@ModelExportMethodRegistry.register("torchscript_int8@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile-metal@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile-vulkan@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@scripting")
@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile-metal@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile-vulkan@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
class TracingAdaptedTorchscriptExport(DefaultTorchscriptExport):
@classmethod
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import unittest
from typing import List
import torch
import torch.nn as nn
from d2go.export.api import (
convert_and_export_predictor,
PredictorExportConfig,
FuncInfo,
)
from d2go.export.fb.caffe2 import DefaultCaffe2Export
from d2go.export.torchscript import (
DefaultTorchscriptExport,
TracingAdaptedTorchscriptExport,
)
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.predictor.api import create_predictor
from parameterized import parameterized
class SimpleModel(nn.Module):
    """Minimal traceable model: doubles the input tensor."""

    def forward(self, x):
        return x + x

    def prepare_for_export(self, cfg, inputs, predictor_type):
        # Pre/post processing and run_func keep their defaults; only the
        # data generator is supplied, turning model(x) into model(*(x,)).
        return PredictorExportConfig(
            model=self,
            data_generator=lambda t: (t,),
        )
class TwoPartSimpleModel(nn.Module):
    """
    Simulates a model whose middle section can't be traced, therefore the
    model is exported as two separately handled parts glued by a run func.
    """

    def __init__(self):
        super().__init__()
        self.part1 = SimpleModel()
        self.part2 = SimpleModel()

    def forward(self, x):
        out = self.part1(x)
        out = TwoPartSimpleModel.non_traceable_func(out)
        return self.part2(out)

    def prepare_for_export(self, cfg, inputs, predictor_type):
        def data_generator(x):
            # Reproduce the forward pass up to each part's boundary so the
            # exporter sees the exact inputs each part would receive.
            part1_args = (x,)
            mid = TwoPartSimpleModel.non_traceable_func(self.part1(x))
            return {"part1": part1_args, "part2": (mid,)}

        return PredictorExportConfig(
            model={"part1": self.part1, "part2": self.part2},
            data_generator=data_generator,
            run_func_info=FuncInfo.gen_func_info(TwoPartSimpleModel.RunFunc, params={}),
        )

    @staticmethod
    def non_traceable_func(x):
        # Python-level branch on tensor rank: exactly what tracing can't capture.
        return x + 1 if len(x.shape) > 3 else x - 1

    class RunFunc(object):
        def __call__(self, model, x):
            assert isinstance(model, dict)
            out = model["part1"](x)
            out = TwoPartSimpleModel.non_traceable_func(out)
            return model["part2"](out)
class ScriptingOnlyModel(nn.Module):
    """
    Example of a model that must be scripted rather than traced (its
    output depends on a Python-level loop over the input list).
    """

    def forward(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        scaled = []
        for idx, tensor in enumerate(inputs):
            scaled.append(tensor * idx)
        return scaled

    def prepare_for_export(self, cfg, inputs, predictor_type):
        if cfg == "explicit":
            # Request script mode directly through the export kwargs.
            return PredictorExportConfig(
                model=self,
                data_generator=None,  # scripting needs no example data
                model_export_kwargs={"jit_mode": "script"},
            )
        if cfg == "implicit":
            # Leave jit_mode unset so the predictor-type string can choose
            # between scripting and tracing without touching this config.
            return PredictorExportConfig(
                model=self,
                data_generator=None,  # scripting needs no example data
            )
        raise NotImplementedError()
class TestExportAPI(unittest.TestCase):
    def _export_simple_model(self, cfg, model, data, output_dir, predictor_type):
        # Run the export pipeline end-to-end, then reload the artifact to
        # confirm the directory on disk is a loadable predictor.
        predictor_path = convert_and_export_predictor(
            cfg,
            model,
            predictor_type=predictor_type,
            output_dir=output_dir,
            data_loader=iter([data] * 3),
        )
        self.assertTrue(os.path.isdir(predictor_path))
        return create_predictor(predictor_path)

    def test_simple_model(self):
        with make_temp_directory("test_simple_model") as tmp_dir:
            model = SimpleModel()
            predictor = self._export_simple_model(
                None, model, torch.tensor(1), tmp_dir, predictor_type="torchscript"
            )
            sample = torch.tensor(42)
            self.assertEqual(predictor(sample), model(sample))

    def test_simple_two_part_model(self):
        with make_temp_directory("test_simple_two_part_model") as tmp_dir:
            model = TwoPartSimpleModel()
            predictor = self._export_simple_model(
                None, model, torch.tensor(1), tmp_dir, predictor_type="torchscript"
            )
            sample = torch.tensor(42)
            self.assertEqual(predictor(sample), model(sample))

    def test_script_only_model(self):
        def _validate(predictor):
            outputs = predictor([torch.tensor(1), torch.tensor(2), torch.tensor(3)])
            self.assertEqual(len(outputs), 3)
            self.assertEqual(
                outputs, [torch.tensor(0), torch.tensor(2), torch.tensor(6)]
            )

        # Method 1: jit_mode is set to "script" inside prepare_for_export
        with make_temp_directory("test_test_script_only_model") as tmp_dir:
            predictor = self._export_simple_model(
                "explicit", ScriptingOnlyModel(), None, tmp_dir,
                predictor_type="torchscript",
            )
            _validate(predictor)

        # Method 2: the predictor type "torchscript@scripting" requests it
        with make_temp_directory("test_test_script_only_model") as tmp_dir:
            predictor = self._export_simple_model(
                "implicit", ScriptingOnlyModel(), None, tmp_dir,
                predictor_type="torchscript@scripting",
            )
            _validate(predictor)
class MultiTensorInSingleTensorOut(nn.Module):
    """Adapter-test model: two tensor inputs, one tensor output."""

    def forward(self, x, y):
        return x + y

    @staticmethod
    def get_input_args():
        # Example positional args to feed forward() during export.
        return (torch.tensor([2]), torch.tensor([3]))

    @staticmethod
    def check_outputs(new_output, original_output):
        # torch.testing.assert_allclose is deprecated (and later removed);
        # assert_close is the supported replacement.
        torch.testing.assert_close(new_output, torch.tensor([5]))
# NOTE: caffe2 wrapper assumes tensors are fp32
class SingleListInSingleListOut(nn.Module):
    """Adapter-test model: one list-of-tensors input, list output."""

    def forward(self, inputs):
        x, y = inputs
        return [x + y]

    @staticmethod
    def get_input_args():
        # fp32 tensors on purpose — see the NOTE above about the caffe2 wrapper.
        inputs = [torch.tensor([2.0]), torch.tensor([3.0])]
        return (inputs,)

    @staticmethod
    def check_outputs(new_output, original_output):
        assert len(new_output) == 1
        # torch.testing.assert_allclose is deprecated (and later removed);
        # assert_close is the supported replacement.
        torch.testing.assert_close(new_output[0], torch.tensor([5.0]))
class MultiDictInMultiDictOut(nn.Module):
    """Adapter-test model: two dict inputs, a list of two dict outputs."""

    def forward(self, x, y):
        first = {"add": x["first"] + y["first"], "sub": x["first"] - y["first"]}
        second = {"add": x["second"] + y["second"], "sub": x["second"] - y["second"]}
        return [first, second]

    @staticmethod
    def get_input_args():
        return (
            {"first": torch.tensor([1]), "second": torch.tensor([2])},  # x
            {"first": torch.tensor([3]), "second": torch.tensor([4])},  # y
        )

    @staticmethod
    def check_outputs(new_output, original_output):
        first, second = original_output
        # torch.testing.assert_allclose is deprecated (and later removed);
        # assert_close is the supported replacement.
        torch.testing.assert_close(first["add"], torch.tensor([4]))
        torch.testing.assert_close(first["sub"], torch.tensor([-2]))
        torch.testing.assert_close(second["add"], torch.tensor([6]))
        torch.testing.assert_close(second["sub"], torch.tensor([-2]))
class TestModelExportMethods(unittest.TestCase):
    @parameterized.expand(
        [
            [DefaultTorchscriptExport, MultiTensorInSingleTensorOut],
            [DefaultTorchscriptExport, SingleListInSingleListOut],
            # [DefaultCaffe2Export, MultiTensorInSingleTensorOut], # TODO: make caffe2 support this
            [DefaultCaffe2Export, SingleListInSingleListOut],
            [TracingAdaptedTorchscriptExport, MultiTensorInSingleTensorOut],
            [TracingAdaptedTorchscriptExport, SingleListInSingleListOut],
            [TracingAdaptedTorchscriptExport, MultiDictInMultiDictOut],
        ],
        # Generated test names embed both the export method and model class.
        name_func=lambda testcase_func, param_num, param: "{}_{}_{}".format(
            testcase_func.__name__, param.args[0].__name__, param.args[1].__name__
        ),
    )
    def test_interface(self, model_export_method, test_model_class):
        # Round-trip: export the toy model, reload it, and let the model
        # class itself validate the reloaded model's outputs.
        toy_model = test_model_class()
        model_export_method.test_export_and_load(
            toy_model,
            test_model_class.get_input_args(),
            None,
            {},
            test_model_class.check_outputs,
        )
import unittest
from d2go.export.torchscript import (
MobileOptimizationConfig,
update_export_kwargs_from_export_method,
)
@update_export_kwargs_from_export_method
def mock_export(cls, model, input_args, save_path, export_method, **export_kwargs):
    """Stub export function: performs no export, just surfaces the kwargs
    that the decorator derived from the `export_method` string."""
    # Return the export kwargs, so that we can check to make sure it's set as expected
    return export_kwargs
class TestTorchscriptExportMethods(unittest.TestCase):
    def test_update_export_kwargs_from_export_method(self):
        """Verify each trigger word in the export-method string is turned
        into the matching export kwarg by the decorator."""

        def try_mock_export(export_method: str, export_kwargs=None):
            # None sentinel instead of a shared module-level dict default,
            # so call sites can't accidentally share mutable state.
            return mock_export(
                cls=None,
                model=None,
                input_args=None,
                save_path=None,
                export_method=export_method,
                **(export_kwargs if export_kwargs is not None else {}),
            )

        # Plain torchscript: no mobile optimization is injected.
        new_export_kwargs = try_mock_export("torchscript")
        self.assertNotIn("mobile_optimization", new_export_kwargs)

        # "_mobile" injects a default (CPU-backend) MobileOptimizationConfig.
        new_export_kwargs = try_mock_export("torchscript_mobile")
        self.assertIn("mobile_optimization", new_export_kwargs)
        # assertEquals is a deprecated alias of assertEqual.
        self.assertEqual(
            type(new_export_kwargs["mobile_optimization"]),
            MobileOptimizationConfig,
        )
        self.assertEqual(new_export_kwargs["mobile_optimization"].backend, "CPU")

        # GPU backends are selected via the "-metal"/"-vulkan" suffixes.
        new_export_kwargs = try_mock_export("torchscript_mobile-metal")
        self.assertEqual(new_export_kwargs["mobile_optimization"].backend, "metal")

        new_export_kwargs = try_mock_export("torchscript_mobile-vulkan")
        self.assertEqual(new_export_kwargs["mobile_optimization"].backend, "vulkan")

        # "@tracing"/"@scripting" map to the jit_mode kwarg.
        new_export_kwargs = try_mock_export("torchscript_mobile@tracing")
        self.assertEqual(new_export_kwargs["jit_mode"], "trace")

        new_export_kwargs = try_mock_export("torchscript_mobile@scripting")
        self.assertEqual(new_export_kwargs["jit_mode"], "script")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment