    enable moving traced model between devices · 2235f180
    Yanghan Wang authored
    Summary:
    X-link: https://github.com/facebookresearch/detectron2/pull/4132
    
    X-link: https://github.com/fairinternal/detectron2/pull/568
    
    Pull Request resolved: https://github.com/facebookresearch/d2go/pull/203
    
    For full discussion: https://fb.workplace.com/groups/1405155842844877/posts/5744470455580039
    
    Tracing `.to(device)` causes problems when moving the traced TorchScript model to another device (e.g. from CPU to GPU, or even from `cuda:0` to `cuda:1`). The reason is that `device` is not a `torch.Tensor`, so the tracer simply hardcodes its value during tracing. The solution is to script the casting operation instead.
    
    Here's the code snippet illustrating this:
    ```
    # imports needed to run the snippets below
    import copy

    import torch
    from torch import nn
    from torch.nn import functional as F

    # define MyModel similar to GeneralizedRCNN, which casts the input to the model's device
    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.conv1 = nn.Conv2d(3, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            # cast the input to the same device as this model; this makes it
            # possible to take a CPU tensor as input when the model is on GPU
            x = x.to(self.conv1.weight.device)
    
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
    
    # export the model by tracing
    model = MyModel()
    x = torch.zeros([1, 3, 32, 32])
    ts = torch.jit.trace(model, x)
    print(ts.graph)
    
    # =====================================================
    graph(%self.1 : __torch__.MyModel,
          %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
      %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="conv2"](%self.1)
      %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
      %14 : int = prim::Constant[value=6]() # <ipython-input-2-5abde0efc36f>:11:0
      %15 : int = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
      %16 : Device = prim::Constant[value="cpu"]() # <ipython-input-2-5abde0efc36f>:11:0
      %17 : NoneType = prim::Constant()
      %18 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
      %19 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
      %20 : NoneType = prim::Constant()
      %input.1 : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu) = aten::to(%x, %14, %15, %16, %17, %18, %19, %20) # <ipython-input-2-5abde0efc36f>:11:0
      %72 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
      %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%72) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
      %73 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
      %61 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%73) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
      return (%61)
    # =====================================================
    
    # PyTorch cuda works
    model = copy.deepcopy(model)
    model.to("cuda")
    y = model(x)
    # torchscript cpu works
    y = ts(x)
    # torchscript cuda doesn't work
    ts = ts.to("cuda")
    y = ts(x)
    
    # =====================================================
    RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-4-2aece3ad6c9a> in <module>
          7 # torchscript cuda doesn't work
          8 ts = ts.to("cuda")
    ----> 9 y = ts(x)
    /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []
    RuntimeError: The following operation failed in the TorchScript interpreter.
    # =====================================================
    
    # One solution is to script the cast instead of tracing it; the following
    # code demonstrates how to do it with mixed scripting/tracing
    @torch.jit.script_if_tracing
    def cast_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        return src.to(dst.device)
    
    class MyModel2(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.conv1 = nn.Conv2d(3, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            # cast the input to the same device as this model; this makes it
            # possible to take a CPU tensor as input when the model is on GPU
            x = cast_device_like(x, self.conv1.weight)
    
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
    
    # export the model by tracing
    model = MyModel2()
    x = torch.zeros([1, 3, 32, 32])
    ts = torch.jit.trace(model, x)
    print(ts.graph)
    
    # =====================================================
    graph(%self.1 : __torch__.MyModel2,
          %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
      %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_5.Conv2d = prim::GetAttr[name="conv2"](%self.1)
      %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
      %conv1.1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
      %weight.5 : Tensor = prim::GetAttr[name="weight"](%conv1.1)
      %14 : Function = prim::Constant[name="cast_device_like"]()
      %input.1 : Tensor = prim::CallFunction(%14, %x, %weight.5)
      %68 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
      %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%68) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
      %69 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
      %55 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%69) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
      return (%55)
    # =====================================================
    
    # PyTorch cuda works
    model = copy.deepcopy(model)
    model.to("cuda")
    y = model(x)
    # torchscript cpu works
    y = ts(x)
    # Note that now torchscript cuda works
    ts = ts.to("cuda")
    y = ts(x)
    print(y.device)
    
    # =====================================================
    cuda:0
    # =====================================================
    ```
    
    For D2, this diff creates a `move_tensor_device_same_as_another(A, B)` function to replace `A.to(B.device)`, and updates `rcnn.py` and all of its utils.
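
    A minimal sketch of what such a helper could look like (hypothetical; the actual d2go implementation may differ), the key point being that the scripted function reads the device at run time:

    ```python
    import torch

    # Hedged sketch of a move_tensor_device_same_as_another-style helper.
    # Because the function is compiled by the scripter when tracing,
    # `dst.device` is read at run time instead of being hardcoded into
    # the traced graph as a constant.
    @torch.jit.script_if_tracing
    def move_tensor_device_same_as_another(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        return src.to(dst.device)
    ```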
    
    For D2Go (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb), since the exported model becomes device-agnostic, we can remove the "_gpu" suffix from the predictor type.
    
    Update (April 11):
    Added a test covering tracing on one device and moving the traced model to another device for inference. When a GPU is available, it traces on `cuda:0` and runs inference on `cpu` and `cuda:0` (and `cuda:N-1` if available).
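
    A hedged sketch of the shape of such a test (a toy model instead of RCNN; `cast_device_like` is the illustrative helper from the snippet above, not the d2go API):

    ```python
    import torch
    from torch import nn
    from torch.nn import functional as F

    @torch.jit.script_if_tracing
    def cast_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
        # scripted, so the device is resolved at run time
        return src.to(dst.device)

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 4, 3)

        def forward(self, x):
            x = cast_device_like(x, self.conv.weight)
            return F.relu(self.conv(x))

    # trace on the default device, then move the traced module around
    model = TinyModel().eval()
    x = torch.zeros([1, 3, 8, 8])
    ts = torch.jit.trace(model, x)

    devices = ["cpu"] + (["cuda:0"] if torch.cuda.is_available() else [])
    for dev in devices:
        ts = ts.to(dev)
        y = ts(x)  # input stays on CPU; the scripted cast moves it at run time
        assert y.device.type == torch.device(dev).type
    ```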
    
    Summary of the device-related patterns:
    - Using `.to(dtype=another_dtype)` doesn't affect the device.
    - Explicit device casting like `.to(device)` can generally be replaced by `move_device_like`.
    - For tensors created directly on a device (e.g. `torch.zeros`, `torch.arange`), we can replace the creation with a scripted function to avoid first creating on CPU and then moving to the target device.
        - Creating tensors on the tracing device and then moving them to the target device is dangerous, because the tracing device (e.g. `cuda:0`) might not be available at inference time (e.g. when running on a CPU-only machine).
        - It's hard to write `image_list.py` in this pattern because the size behaves differently during tracing (int vs. scalar tensor); in this diff it still creates on CPU first and then moves to the target device.
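
    The on-device creation pattern above can be sketched as follows (hypothetical helper name, not the actual d2go code):

    ```python
    import torch

    # Hypothetical example: create a tensor directly on the device of a
    # reference tensor. Scripting keeps `like.device` dynamic under tracing,
    # so the traced graph never bakes in the tracing device.
    @torch.jit.script_if_tracing
    def arange_on_device_of(n: int, like: torch.Tensor) -> torch.Tensor:
        return torch.arange(n, device=like.device)
    ```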
    
    Reviewed By: tglik
    
    Differential Revision: D35367772
    
    fbshipit-source-id: 02d07e3d96da85f4cfbeb996e3c14c2a6f619beb