Commit 6ca4702b authored by Simon Hollis's avatar Simon Hollis Committed by Facebook GitHub Bot
Browse files

Fix import exceptions in tracing.py for older (<1.12) versions of pytorch

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/362

X-link: https://github.com/facebookresearch/detectron2/pull/4491

Recently landed D35518556 (https://github.com/facebookresearch/d2go/commit/1ffc801bf5a1d4fe926b815ba93f21632f0980f9) / Github: 36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Reviewed By: wat3rBro

Differential Revision: D38879134

fbshipit-source-id: 72f5a7a8d350eb82be87567f006368bf207f5a74
parent 53f9eee2
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
from detectron2 import layers from detectron2 import layers
from detectron2.utils.tracing import assert_fx_safe from detectron2.utils.tracing import is_fx_tracing
from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock
...@@ -34,7 +33,8 @@ class RPNHeadConvRegressor(nn.Module): ...@@ -34,7 +33,8 @@ class RPNHeadConvRegressor(nn.Module):
torch.nn.init.constant_(l.bias, 0) torch.nn.init.constant_(l.bias, 0)
def forward(self, x: List[torch.Tensor]): def forward(self, x: List[torch.Tensor]):
assert_fx_safe(isinstance(x, (list, tuple)), "Unexpected data type") if not is_fx_tracing():
torch._assert(isinstance(x, (list, tuple)), "Unexpected data type")
logits = [self.cls_logits(y) for y in x] logits = [self.cls_logits(y) for y in x]
bbox_reg = [self.bbox_pred(y) for y in x] bbox_reg = [self.bbox_pred(y) for y in x]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment