Commit 41a9c708 authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Changed prefetching.

parent 44d68ff2
...@@ -100,7 +100,10 @@ class GlobalPageManager: ...@@ -100,7 +100,10 @@ class GlobalPageManager:
return cls._instance return cls._instance
def prefetch_all(self, to_cpu=False): def prefetch_all(self, to_cpu=False):
for t in self.paged_tensors: # assume the first added, will be hte
# ones that are used first, so swap them in last
# in the case they are evicted again
for t in self.paged_tensors[::-1]:
prefetch_tensor(t, to_cpu) prefetch_tensor(t, to_cpu)
......
...@@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -256,7 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True self.initialized = True
if self.is_paged: self.page_mng.prefetch_all() #if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]): for pindex, p in enumerate(group["params"]):
if p.grad is None: if p.grad is None:
...@@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -265,7 +265,9 @@ class Optimizer8bit(torch.optim.Optimizer):
if len(state) == 0: if len(state) == 0:
self.init_state(group, p, gindex, pindex) self.init_state(group, p, gindex, pindex)
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex) self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
if self.is_paged: if self.is_paged:
# all paged operation are asynchronous, we need # all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state # to sync to make sure all tensors are in the right state
...@@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -309,6 +311,13 @@ class Optimizer8bit(torch.optim.Optimizer):
self.page_mng.paged_tensors.append(buff) self.page_mng.paged_tensors.append(buff)
return buff return buff
def prefetch_state(self, p):
if self.is_paged:
state = self.state[p]
F.prefetch_tensor(state['state1'])
if 'state2' in state:
F.prefetch_tensor(state['state2'])
class Optimizer2State(Optimizer8bit): class Optimizer2State(Optimizer8bit):
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment