Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor,extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."%("--"iftensor._baseisNoneelsetensor._base.shape,new_data_tensor.shape))
tensor.data=new_data_tensor
classCudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def__init__(self):
# Map from a string name to the cuda rng state.
self.states_={}
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_=set()
defreset(self):
"""Set to the initial state (no tracker)."""
self.states_={}
self.seeds_=set()
defget_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states={}
fornameinself.states_:
states[name]=self.states_[name]
returnstates
defset_states(self,states):
"""Set the rng states. For efficiency purposes, we do not check