"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d40846d456ecc930c04538778ed11f67cc793c23"
Commit 03421e87 authored by Timothee Cour's avatar Timothee Cour Committed by mcarilli
Browse files
parent 3ae89c75
...@@ -74,10 +74,13 @@ if version_num < 1.1: ...@@ -74,10 +74,13 @@ if version_num < 1.1:
_bmms = ['addbmm', _bmms = ['addbmm',
'baddbmm', 'baddbmm',
'bmm'] 'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms) if utils.is_cuda_enabled():
else: # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
FP32_FUNCS.extend(_bmms) if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms)
else:
FP32_FUNCS.extend(_bmms)
# Multi-tensor fns that may need type promotion # Multi-tensor fns that may need type promotion
CASTS = [ CASTS = [
......
...@@ -5,6 +5,9 @@ import itertools ...@@ -5,6 +5,9 @@ import itertools
import torch import torch
def is_cuda_enabled():
return torch.version.cuda is not None
def get_cuda_version(): def get_cuda_version():
return tuple(int(x) for x in torch.version.cuda.split('.')) return tuple(int(x) for x in torch.version.cuda.split('.'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment