Commit 69907b80 authored by Michael Carilli's avatar Michael Carilli
Browse files

Lazily trigger fused_norm for FusedAdam error on Windows

parent 1b903852
......@@ -2,17 +2,27 @@ import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ctypes
lib = ctypes.cdll.LoadLibrary(None)
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
lib.THCudaHalfTensor_normall.restype = ctypes.c_float
def fused_norm(input):
if input.type() == 'torch.cuda.HalfTensor':
# 16384 is half 2 if you stare at it long enough
return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
input._cdata, 16384)
else:
return input.norm()
stashed_err = None
try:
lib = ctypes.cdll.LoadLibrary(None)
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
lib.THCudaHalfTensor_normall.restype = ctypes.c_float
def fused_norm(input):
if input.type() == 'torch.cuda.HalfTensor':
# 16384 is half 2 if you stare at it long enough
return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
input._cdata, 16384)
else:
return input.norm()
except TypeError as err:
stashed_err = err
def fused_norm(input):
raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
"because of lib = ctypes.cdll.LoadLibrary(None): you can't "
"LoadLibrary with None. Original exception message was ",
stashed_err)
class FP16_Optimizer(object):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment