Commit 2235f180 authored by Yanghan Wang, committed by Facebook GitHub Bot

enable moving traced model between devices

Summary:
X-link: https://github.com/facebookresearch/detectron2/pull/4132

X-link: https://github.com/fairinternal/detectron2/pull/568

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/203

For full discussion: https://fb.workplace.com/groups/1405155842844877/posts/5744470455580039

Tracing `.to(device)` causes problems when moving the traced TorchScript module to another device (e.g. from CPU to GPU, or even from `cuda:0` to `cuda:1`). The reason is that `device` is not a `torch.Tensor`, so the tracer simply hardcodes its value during tracing. The solution is to script the casting operation.

Here's the code snippet illustrating this:
```
import copy

import torch
import torch.nn.functional as F
from torch import nn

# define MyModel similar to GeneralizedRCNN, which casts the input to the model's device
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model, this makes it possible to
        # take a cpu tensor as input when the model is on GPU.
        x = x.to(self.conv1.weight.device)

        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %14 : int = prim::Constant[value=6]() # <ipython-input-2-5abde0efc36f>:11:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %16 : Device = prim::Constant[value="cpu"]() # <ipython-input-2-5abde0efc36f>:11:0
  %17 : NoneType = prim::Constant()
  %18 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %19 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %20 : NoneType = prim::Constant()
  %input.1 : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu) = aten::to(%x, %14, %15, %16, %17, %18, %19, %20) # <ipython-input-2-5abde0efc36f>:11:0
  %72 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%72) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %73 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %61 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%73) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%61)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)
# torchscript cpu works
y = ts(x)
# torchscript cuda doesn't work
ts = ts.to("cuda")
y = ts(x)

# =====================================================
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-2aece3ad6c9a> in <module>
      7 # torchscript cuda doesn't work
      8 ts = ts.to("cuda")
----> 9 y = ts(x)
/mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []
RuntimeError: The following operation failed in the TorchScript interpreter.
# =====================================================

# One solution is to script the casting instead of tracing it; the following code
# demonstrates how, using mixed scripting/tracing.
@torch.jit.script_if_tracing
def cast_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    return src.to(dst.device)

class MyModel2(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model, this makes it possible to
        # take a cpu tensor as input when the model is on GPU.
        x = cast_device_like(x, self.conv1.weight)

        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel2()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel2,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_5.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %conv1.1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %weight.5 : Tensor = prim::GetAttr[name="weight"](%conv1.1)
  %14 : Function = prim::Constant[name="cast_device_like"]()
  %input.1 : Tensor = prim::CallFunction(%14, %x, %weight.5)
  %68 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%68) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %69 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %55 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%69) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%55)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)
# torchscript cpu works
y = ts(x)
# Note that now torchscript cuda works
ts = ts.to("cuda")
y = ts(x)
print(y.device)

# =====================================================
cuda:0
# =====================================================
```

For D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb), this diff creates a `move_tensor_device_same_as_another(A, B)` function to replace `A.to(B.device)`, and updates `rcnn.py` and all its utils accordingly.
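A minimal sketch of what such a helper can look like, assuming it follows the same `torch.jit.script_if_tracing` pattern as `cast_device_like` above (the exact name and location in the diff may differ):

```
import torch

@torch.jit.script_if_tracing
def move_tensor_device_same_as_another(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    # Scripted rather than traced, so `dst.device` is read at runtime
    # instead of being baked into the graph as a constant.
    return src.to(dst.device)
```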

For D2Go, since the exported model becomes device-agnostic, we can remove the "_gpu" suffix from the predictor type.

Update (April 11):
Added a test covering tracing on one device and moving the traced model to another device for inference. When a GPU is available, it traces on `cuda:0` and runs inference on `cpu` and `cuda:0` (and `cuda:N-1` if more than one GPU is available).
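A rough sketch of the shape of that test (hypothetical helper name and structure; the actual test lives in the RCNN test utilities touched by this diff):

```
import torch

def check_device_agnostic(ts_model, example_input):
    # Move the traced module to each available device and verify that
    # inference works and the output lands on that device.
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda:0")
        if torch.cuda.device_count() > 1:
            devices.append("cuda:{}".format(torch.cuda.device_count() - 1))
    for device in devices:
        ts_model = ts_model.to(device)
        output = ts_model(example_input)  # input stays on CPU; the model casts it
        assert output.device == torch.device(device)
```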

Summary of the device-related patterns:
- The usage of `.to(dtype=another_dtype)` doesn't affect the device.
- Explicit device casting like `.to(device)` can generally be replaced by `move_device_like`.
- For creating tensors directly on a device (e.g. `torch.zeros`, `torch.arange`), we can wrap the creation in a scripted function to avoid first creating on CPU and then moving to the new device (see the sketch after this list).
    - Creating tensors on the tracing device and then moving them to a new device is dangerous, because the tracing device (e.g. `cuda:0`) might not be available at inference time (e.g. on a CPU-only machine).
    - It's hard to write `image_list.py` in this pattern because the size behaves differently during tracing (int vs. scalar tensor); in this diff it still creates on CPU first and then moves to the target device.
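A sketch of the on-device creation pattern described above (hypothetical helper name, following the same `script_if_tracing` approach):

```
import torch

@torch.jit.script_if_tracing
def arange_like(like: torch.Tensor, end: int) -> torch.Tensor:
    # Create the tensor directly on the reference tensor's device, so the trace
    # contains neither a hardcoded device constant nor a CPU -> device copy.
    return torch.arange(end, device=like.device)
```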

Reviewed By: tglik

Differential Revision: D35367772

fbshipit-source-id: 02d07e3d96da85f4cfbeb996e3c14c2a6f619beb
parent 8d5c70e9
```
@@ -62,14 +62,6 @@ class GeneralizedRCNNPatch:
     def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
         pytorch_model = self
-        # NOTE: currently Exporter doesn't support specifying exporting GPU model via
-        # `model_export_method` in a general way. For RCNN model, we only need to cast
-        # the model to GPU and trace the model (scripting might not work) normally to
-        # get the GPU torchscripts.
-        if "_gpu" in predictor_type:
-            pytorch_model = _cast_detection_model(pytorch_model, "cuda")
-            predictor_type = predictor_type.replace("_gpu", "", 1)
         if (
             "@c2_ops" in predictor_type
             or "caffe2" in predictor_type
```
```
@@ -183,10 +183,8 @@ class MockRCNNInference(object):
         return results


-def _validate_outputs(inputs, outputs, is_gpu=False):
+def _validate_outputs(inputs, outputs):
     assert len(inputs) == len(outputs)
-    if is_gpu:
-        assert outputs[0]["instances"].pred_classes.device.type == "cuda"
     # TODO: figure out how to validate outputs
@@ -353,8 +351,7 @@ class RCNNBaseTestCases:
             predictor = create_predictor(predictor_path)
             predictor_outputs = predictor(inputs)
-            is_gpu = self.cfg.MODEL.DEVICE != "cpu" or "_gpu" in predictor_type
-            _validate_outputs(inputs, predictor_outputs, is_gpu=is_gpu)
+            _validate_outputs(inputs, predictor_outputs)

             if compare_match:
                 with torch.no_grad():
```
```
@@ -41,7 +41,6 @@ class TestFBNetV3MaskRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
         [
             ["torchscript@c2_ops", True],
             ["torchscript", True],
-            ["torchscript_gpu", False],  # can't compare across device
             ["torchscript_int8@c2_ops", False],
             ["torchscript_int8", False],
         ]
```
```
@@ -12,7 +12,7 @@ from d2go.utils.testing.rcnn_helper import get_quick_test_config_opts
 from mobile_cv.common.misc.file_utils import make_temp_directory


-def maskrcnn_export_caffe2_vs_torchvision_opset_format_example():
+def maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
     with make_temp_directory("export_demo") as tmp_dir:
         # use a fake dataset for ci
         dataset_name = create_local_dataset(tmp_dir, 5, 224, 224)
@@ -48,6 +48,9 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example():
     # Running inference using torchvision-style format
     image = torch.zeros(1, 64, 96)  # chw 3D tensor
+    # The exported model can run on both cpu/gpu
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    torchvision_ops_model = torchvision_ops_model.to(device)
     torchvision_style_outputs = torchvision_ops_model(
         image
     )  # suppose N instances are detected
@@ -56,10 +59,20 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example():
     # be difficult to figure out just from model.jit file. The predictor_info.json from
     # the same directory contains the `outputs_schema`, which indicate how the final output
     # is constructed from flattened tensors.
-    pred_boxes = torchvision_style_outputs[0]  # torch.Size([N, 4])
-    pred_classes = torchvision_style_outputs[1]  # torch.Size([N])
-    pred_masks = torchvision_style_outputs[2]  # torch.Size([N, 1, Hmask, Wmask])
-    scores = torchvision_style_outputs[3]  # torch.Size([N])
+    (
+        pred_boxes,  # torch.Size([N, 4])
+        pred_classes,  # torch.Size([N])
+        pred_masks,  # torch.Size([N, 1, Hmask, Wmask])
+        scores,  # torch.Size([N])
+        image_sizes,  # torch.Size([2])
+    ) = torchvision_style_outputs
+    self.assertTrue(
+        all(
+            x.device == torch.device(device) for x in torchvision_style_outputs[:4]
+        ),
+        torchvision_style_outputs,
+    )
+    torch.testing.assert_close(image_sizes, torch.tensor([64, 96]))

     # Running inference using caffe2-style format
     data = torch.zeros(1, 1, 64, 96)
@@ -73,16 +86,20 @@ def maskrcnn_export_caffe2_vs_torchvision_opset_format_example():
     mask_fcn_probs = caffe2_style_outputs[3]  # torch.Size([N, Cmask, Hmask, Wmask])

     # relations between torchvision-style outputs and caffe2-style outputs
-    torch.testing.assert_allclose(pred_boxes, roi_bbox_nms)
-    torch.testing.assert_allclose(pred_classes, roi_class_nms)
-    torch.testing.assert_allclose(
-        pred_masks, mask_fcn_probs[:, roi_class_nms.to(torch.int64), :, :]
-    )
-    torch.testing.assert_allclose(scores, roi_score_nms)
+    torch.testing.assert_close(pred_boxes, roi_bbox_nms, check_device=False)
+    torch.testing.assert_close(
+        pred_classes, roi_class_nms.to(torch.int64), check_device=False
+    )
+    torch.testing.assert_close(
+        pred_masks,
+        mask_fcn_probs[:, roi_class_nms.to(torch.int64), :, :],
+        check_device=False,
+    )
+    torch.testing.assert_close(scores, roi_score_nms, check_device=False)
     # END_WIKI_EXAMPLE_TAG


 class TestOptimizer(unittest.TestCase):
     @unittest.skipIf(os.getenv("OSSRUN") == "1", "Caffe2 is not available for OSS")
     def test_maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self):
-        maskrcnn_export_caffe2_vs_torchvision_opset_format_example()
+        maskrcnn_export_caffe2_vs_torchvision_opset_format_example(self)
```