Unverified Commit a8c2b7dd authored by ngimel, committed by GitHub

Merge pull request #225 from NVIDIA/bmm-fp16

Conditionally run bmm functions in fp16 based on cuda version
parents f5cd5ae9 f1123e32
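Background for the change: before CUDA 9.1, cuBLAS lacked fast FP16 kernels for batched matmul, so routing `bmm`/`addbmm`/`baddbmm` through FP16 could hurt rather than help. A hypothetical micro-benchmark (not part of this commit, and with assumed shapes/iteration counts) that one could run to compare the two precisions on the local GPU might look like this:

```python
# Hypothetical sanity check, not part of the commit: compare torch.bmm
# throughput in FP32 vs FP16 on the current device/CUDA combination.
import time
import torch

def time_bmm(dtype, iters=100, shape=(64, 256, 256)):
    # Two random batched operands in the requested precision.
    a = torch.randn(*shape, dtype=dtype, device='cuda')
    b = torch.randn(*shape, dtype=dtype, device='cuda')
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        torch.bmm(a, b)
    torch.cuda.synchronize()
    return time.time() - start

print('fp32 bmm:', time_bmm(torch.float32))
print('fp16 bmm:', time_bmm(torch.float16))
```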
import torch
from .. import utils
MODULE = torch
FP16_FUNCS = [
@@ -20,10 +22,8 @@ FP16_FUNCS = [
'matmul',
'mm',
'mv',
]
# TODO: ban in-place versions of these in fp16
FP32_FUNCS = [
# Pointwise
'acos',
@@ -54,15 +54,21 @@ FP32_FUNCS = [
'sum',
'var',
# Special reduction-like BLAS
'addbmm',
'baddbmm',
'bmm',
# Misc
'renorm'
]
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
'baddbmm',
'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms)
else:
FP32_FUNCS.extend(_bmms)
# Multi-tensor fns that may need type promotion
CASTS = [
# Multi-tensor math
@@ -87,8 +93,9 @@ CASTS = [
'ne'
]
# Will possibly need to promote *all* elements of `seq`
# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
'cat', # torch.cat(seq, dim=0, out=None)
'stack' # torch.stack(seq, dim=0, out=None)
'cat',
'stack'
]
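For context on what these lists drive (a loose sketch, not apex's actual wrapping code): amp patches the listed functions on `MODULE` so that floating-point tensor arguments are cast before the underlying call. A minimal illustration of that pattern, with hypothetical helper names:

```python
# Rough illustration only; apex's real casting logic lives elsewhere in amp.
import functools
import torch

def _cast_tensor_args(args, dtype):
    # Cast floating-point tensor arguments to `dtype`; pass everything else through.
    return tuple(a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
                 for a in args)

def patch_fn(module, name, dtype):
    # Replace module.<name> with a wrapper that casts its tensor inputs first.
    orig = getattr(module, name)

    @functools.wraps(orig)
    def wrapper(*args, **kwargs):
        return orig(*_cast_tensor_args(args, dtype), **kwargs)

    setattr(module, name, wrapper)

# e.g.:
# for name in FP16_FUNCS: patch_fn(MODULE, name, torch.float16)
# for name in FP32_FUNCS: patch_fn(MODULE, name, torch.float32)
```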
@@ -5,6 +5,9 @@ import itertools
import torch
def get_cuda_version():
return tuple(int(x) for x in torch.version.cuda.split('.'))
def is_fp_tensor(x):
if is_nested(x):
# Fast-fail version of all(is_fp_tensor)
......
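A note on the helper added above: `torch.version.cuda` is a dotted version string, and `get_cuda_version()` turns it into a tuple of ints, so Python's lexicographic tuple comparison is what makes the `>= (9, 1, 0)` check in the lists module work. A small self-contained sketch (the example version strings are assumptions, not values taken from this commit):

```python
# Sketch of the tuple-comparison idea behind get_cuda_version().
def parse_cuda_version(version_string):
    return tuple(int(x) for x in version_string.split('.'))

assert parse_cuda_version("9.2.148") >= (9, 1, 0)      # new enough: bmm goes to the FP16 list
assert not parse_cuda_version("9.0.176") >= (9, 1, 0)  # too old: bmm stays on the FP32 list
```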