Commit 4cc1c1b4 authored by Michael Carilli's avatar Michael Carilli
Browse files

Adding small python wrapper for multi_tensor_apply

parent 6763a8be
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048*32)
import torch
class MultiTensorApply(object):
available = False
warned = False
def __init__(self, chunk_size):
try:
import amp_C
MultiTensorApply.available = True
self.chunk_size = chunk_size
except ImportError as err:
MultiTensorApply.availble = False
MultiTensorApply.import_err = err
def check_avail(self):
if MultiTensorApply.available == False:
raise RuntimeError(
"Attempted to call MultiTensorApply method, but MultiTensorApply "
"is not available, possibly because Apex was installed without "
"--cpp_ext --cuda_ext. Original import error message:",
MultiTensorApply.import_err)
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
self.check_avail()
return op(self.chunk_size,
noop_flag_buffer,
tensor_lists,
*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment