Commit 03421e87 authored by Timothee Cour's avatar Timothee Cour Committed by mcarilli
Browse files
parent 3ae89c75
...@@ -74,9 +74,12 @@ if version_num < 1.1: ...@@ -74,9 +74,12 @@ if version_num < 1.1:
_bmms = ['addbmm', _bmms = ['addbmm',
'baddbmm', 'baddbmm',
'bmm'] 'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
if utils.is_cuda_enabled():
# workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms) FP16_FUNCS.extend(_bmms)
else: else:
FP32_FUNCS.extend(_bmms) FP32_FUNCS.extend(_bmms)
# Multi-tensor fns that may need type promotion # Multi-tensor fns that may need type promotion
......
...@@ -5,6 +5,9 @@ import itertools ...@@ -5,6 +5,9 @@ import itertools
import torch import torch
def is_cuda_enabled():
return torch.version.cuda is not None
def get_cuda_version(): def get_cuda_version():
return tuple(int(x) for x in torch.version.cuda.split('.')) return tuple(int(x) for x in torch.version.cuda.split('.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment