Commit 69907b80 authored by Michael Carilli's avatar Michael Carilli
Browse files

Lazily trigger fused_norm for FusedAdam error on Windows

parent 1b903852
...@@ -2,17 +2,27 @@ import torch ...@@ -2,17 +2,27 @@ import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ctypes import ctypes
lib = ctypes.cdll.LoadLibrary(None) stashed_err = None
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p] try:
lib.THCudaHalfTensor_normall.restype = ctypes.c_float lib = ctypes.cdll.LoadLibrary(None)
lib.THCudaHalfTensor_normall.argtypes=[ctypes.c_void_p, ctypes.c_void_p]
def fused_norm(input): lib.THCudaHalfTensor_normall.restype = ctypes.c_float
if input.type() == 'torch.cuda.HalfTensor':
# 16384 is half 2 if you stare at it long enough def fused_norm(input):
return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata, if input.type() == 'torch.cuda.HalfTensor':
input._cdata, 16384) # 16384 is half 2 if you stare at it long enough
else: return lib.THCudaHalfTensor_normall(torch.cuda._state_cdata,
return input.norm() input._cdata, 16384)
else:
return input.norm()
except TypeError as err:
stashed_err = err
def fused_norm(input):
raise RuntimeError("Failed to create fused_norm. This may happen on Windows "
"because of lib = ctypes.cdll.LoadLibrary(None): you can't "
"LoadLibrary with None. Original exception message was ",
stashed_err)
class FP16_Optimizer(object): class FP16_Optimizer(object):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment