Commit 35ce337b authored by Egor Krivov
Browse files

Fixed bugs

parent b43edf56
...@@ -579,7 +579,7 @@ def _optimizer_update_8bit_blockwise_impl( ...@@ -579,7 +579,7 @@ def _optimizer_update_8bit_blockwise_impl(
g: torch.Tensor, g: torch.Tensor,
p: torch.Tensor, p: torch.Tensor,
state1: torch.Tensor, state1: torch.Tensor,
state2: Optional[torch.nsor], state2: Optional[torch.Tensor],
beta1: float, beta1: float,
beta2: float, beta2: float,
beta3: float, beta3: float,
......
...@@ -280,6 +280,7 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -280,6 +280,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.initialized = True self.initialized = True
# if self.is_paged: self.page_mng.prefetch_all() # if self.is_paged: self.page_mng.prefetch_all()
p = None
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]): for pindex, p in enumerate(group["params"]):
if p.grad is None: if p.grad is None:
...@@ -291,10 +292,10 @@ class Optimizer8bit(torch.optim.Optimizer): ...@@ -291,10 +292,10 @@ class Optimizer8bit(torch.optim.Optimizer):
self.prefetch_state(p) self.prefetch_state(p)
self.update_step(group, p, gindex, pindex) self.update_step(group, p, gindex, pindex)
sync_gpu(p) sync_gpu(p)
if self.is_paged: if self.is_paged and p is not None:
# all paged operations are asynchronous, we need # all paged operations are asynchronous, we need
# to sync to make sure all tensors are in the right state # to sync to make sure all tensors are in the right state
sync_gpu(loss) sync_gpu(p)
return loss return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment