Unverified Commit 5212a0f2 authored by Titus, committed by GitHub
Browse files

Edenzzzz's fix for min_8bit_size functionality in Optimizer base classes (#1286)



* fix min_8bit_size invalid bug

* Apply same fix to other optimizer base class

---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent 0bdd57cc
...@@ -437,7 +437,7 @@ class Optimizer2State(Optimizer8bit): ...@@ -437,7 +437,7 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): if dtype == torch.float32:
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
...@@ -656,7 +656,7 @@ class Optimizer1State(Optimizer8bit): ...@@ -656,7 +656,7 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p] state = self.state[p]
state["step"] = 0 state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): if dtype == torch.float32:
state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state["step"] == 0: if state["step"] == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment