assert_viewless_tensor(tensor,extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."%("--"iftensor._baseisNoneelsetensor._base.shape,new_data_tensor.shape))
assert_viewless_tensor(tensor,extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."%("--"iftensor._baseisNoneelsetensor._base.shape,new_data_tensor.shape))
tensor.data=new_data_tensor
tensor.data=new_data_tensor
# <<<
classCudaRNGStatesTracker:
classCudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
"""Tracker for the cuda RNG states.
...
@@ -328,19 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
...
@@ -328,19 +288,10 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# Divide hidden states across model parallel group and only keep