Unverified commit a8c2b7dd, authored by ngimel and committed by GitHub

Merge pull request #225 from NVIDIA/bmm-fp16

Conditionally run bmm functions in fp16 based on cuda version
Parents: f5cd5ae9, f1123e32
 import torch
+from .. import utils
 
 MODULE = torch
 
 FP16_FUNCS = [
@@ -20,10 +22,8 @@ FP16_FUNCS = [
     'matmul',
     'mm',
     'mv',
 ]
 
-# TODO: ban in-place versions of these in fp16
 FP32_FUNCS = [
     # Pointwise
     'acos',
@@ -54,15 +54,21 @@ FP32_FUNCS = [
     'sum',
     'var',
 
-    # Special reduction-like BLAS
-    'addbmm',
-    'baddbmm',
-    'bmm',
-
     # Misc
     'renorm'
 ]
 
+# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
+# check the CUDA version -- if at least 9.1, then put the bmm
+# functions on the fp16 list. Otherwise, put them on the fp32 list.
+_bmms = ['addbmm',
+         'baddbmm',
+         'bmm']
+if utils.get_cuda_version() >= (9, 1, 0):
+    FP16_FUNCS.extend(_bmms)
+else:
+    FP32_FUNCS.extend(_bmms)
+
 # Multi-tensor fns that may need type promotion
 CASTS = [
     # Multi-tensor math
@@ -87,8 +93,9 @@ CASTS = [
     'ne'
 ]
 
-# Will possibly need to promote *all* elements of `seq`
+# Functions that take sequence arguments. We need to inspect the whole
+# sequence and cast to the widest type.
 SEQUENCE_CASTS = [
-    'cat',    # torch.cat(seq, dim=0, out=None)
-    'stack'   # torch.stack(seq, dim=0, out=None)
+    'cat',
+    'stack'
 ]
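
The effect of the added block comes down to the tuple comparison: CUDA 9.0.176 parses to (9, 0, 176), which sorts below (9, 1, 0), so addbmm/baddbmm/bmm stay on the FP32 list, while 9.1.85 or 10.1 parse to tuples at or above (9, 1, 0) and move them to the FP16 list. Below is a minimal standalone sketch of the same gating pattern; the list contents are abbreviated and the None guard for CPU-only PyTorch builds is an assumption of the sketch, not part of this diff.

import torch

# Abbreviated stand-ins for the FP16/FP32 lists in this file.
FP16_FUNCS = ['matmul', 'mm', 'mv']
FP32_FUNCS = ['acos', 'sum', 'var', 'renorm']

def _cuda_version():
    # torch.version.cuda is a string such as "9.1.85" on CUDA builds and
    # None on CPU-only builds; mapping None to (0, 0, 0) is a choice made
    # only for this sketch.
    if torch.version.cuda is None:
        return (0, 0, 0)
    return tuple(int(x) for x in torch.version.cuda.split('.'))

# Batched matmul only has fast FP16 kernels from CUDA 9.1 onward, so the
# bmm family goes on the FP16 list for new toolkits and stays FP32 otherwise.
_bmms = ['addbmm', 'baddbmm', 'bmm']
if _cuda_version() >= (9, 1, 0):
    FP16_FUNCS.extend(_bmms)
else:
    FP32_FUNCS.extend(_bmms)

print('bmm runs in', 'FP16' if 'bmm' in FP16_FUNCS else 'FP32')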
@@ -5,6 +5,9 @@ import itertools
 import torch
 
+def get_cuda_version():
+    return tuple(int(x) for x in torch.version.cuda.split('.'))
+
 def is_fp_tensor(x):
     if is_nested(x):
         # Fast-fail version of all(is_fp_tensor)
...
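
For reference, the one-liner in the new get_cuda_version() orders versions through ordinary tuple comparison, whether the string has two or three components. A small hedged check of that behaviour follows; the version strings are illustrative, and note that torch.version.cuda is None on CPU-only builds, which the helper does not guard against.

def parse(version_string):
    # Same parsing as the new get_cuda_version(), applied to an explicit string.
    return tuple(int(x) for x in version_string.split('.'))

assert parse("9.0.176") < (9, 1, 0)    # pre-9.1 toolkit: bmm functions stay FP32
assert parse("9.1.85") >= (9, 1, 0)    # 9.1+: bmm functions move to the FP16 list
assert parse("10.1") >= (9, 1, 0)      # two-component strings compare correctly too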