Commit 6ca4702b authored by Simon Hollis's avatar Simon Hollis Committed by Facebook GitHub Bot
Browse files

Fix import exceptions in tracing.py for older (<1.12) versions of pytorch

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/362

X-link: https://github.com/facebookresearch/detectron2/pull/4491

Recently landed D35518556 (https://github.com/facebookresearch/d2go/commit/1ffc801bf5a1d4fe926b815ba93f21632f0980f9) / Github: 36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 throws an exception with older versions of PyTorch, due to a missing library for import. This has been reported by multiple members of the PyTorch community at https://github.com/facebookresearch/detectron2/commit/36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0

This change uses `try/except` to check for libraries and set flags on presence/absence to later guard code that would use them.

Reviewed By: wat3rBro

Differential Revision: D38879134

fbshipit-source-id: 72f5a7a8d350eb82be87567f006368bf207f5a74
parent 53f9eee2
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import torch
import torch.nn as nn
from detectron2 import layers
from detectron2.utils.tracing import assert_fx_safe
from detectron2.utils.tracing import is_fx_tracing
from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock
......@@ -34,7 +33,8 @@ class RPNHeadConvRegressor(nn.Module):
torch.nn.init.constant_(l.bias, 0)
def forward(self, x: List[torch.Tensor]):
assert_fx_safe(isinstance(x, (list, tuple)), "Unexpected data type")
if not is_fx_tracing():
torch._assert(isinstance(x, (list, tuple)), "Unexpected data type")
logits = [self.cls_logits(y) for y in x]
bbox_reg = [self.bbox_pred(y) for y in x]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment