Unverified commit 0cafa7f4, authored by youngrok cha, committed by GitHub
Browse files

Specify blocksize (#1586)

* [fix] define blocksize

Define blocksize as a named variable, since showing just a bare number is a bit confusing.

* [fix] match code with AdEMAMix

Also define blocksize, as in the previous commit.
parent 12c40963
......@@ -166,8 +166,9 @@ class AdEMAMix(Optimizer2State):
self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)
blocksize = 256
n = p.numel()
blocks = (n // 256) + bool(n % 256)
blocks = (n // blocksize) + bool(n % blocksize)
state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
......
......@@ -475,9 +475,9 @@ class Optimizer2State(Optimizer8bit):
state["qmap2"] = self.name2qmap["udynamic"]
if config["block_wise"]:
blocksize = 256
n = p.numel()
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0
blocks = (n // blocksize) + bool(n % blocksize)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
......@@ -697,9 +697,9 @@ class Optimizer1State(Optimizer8bit):
state["qmap1"] = self.name2qmap["dynamic"]
if config["block_wise"]:
blocksize = 256
n = p.numel()
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0
blocks = (n // blocksize) + bool(n % blocksize)
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment