Commit a92c6297 authored by Guolin Ke
Browse files

fix bug in flatten parameters

parent 9954c186
......@@ -21,6 +21,9 @@ def check_param_device(params):
assert device == params[i].device
def pad_numel(numel, multiplier=2):
    """Round ``numel`` up to a multiple of ``multiplier``.

    Args:
        numel: Element count to pad.
        multiplier: Granularity to round up to. Defaults to 2.

    Returns:
        For a non-negative integer ``numel`` and a positive integer
        ``multiplier``, the smallest multiple of ``multiplier`` that is
        greater than or equal to ``numel``.
    """
    # Ceiling division gives the number of multiplier-sized chunks needed;
    # scaling back up yields the padded element count.
    num_chunks = (numel + multiplier - 1) // multiplier
    return num_chunks * multiplier
class _FP16OptimizerMixin(object):
def __init__(self, args, **kwargs):
# forward __init__ call to the next class in mro(method resolution order)
......@@ -53,7 +56,7 @@ class _FP16OptimizerMixin(object):
flatten_params = {}
for dtype in dtype_grouped_params:
cur_params = dtype_grouped_params[dtype]
total_param_size = sum(p.data.numel() for p in cur_params)
total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
flatten_params[dtype] = (
cur_params[0].new(0).type(dtype).new(total_param_size)
)
......@@ -64,7 +67,7 @@ class _FP16OptimizerMixin(object):
p.data = (
flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
)
offset += numel
offset += pad_numel(numel)
flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
flatten_params[dtype].grad = flatten_params[dtype].data.new(
total_param_size
......@@ -75,7 +78,7 @@ class _FP16OptimizerMixin(object):
p.grad = (
flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
)
offset += numel
offset += pad_numel(numel)
torch.cuda.empty_cache()
return list(flatten_params.values())
......@@ -119,7 +122,7 @@ class _FP16OptimizerMixin(object):
self.fp32_params.grad.data[offset : offset + numel].copy_(
p.grad.data.view(-1)
)
offset += numel
offset += pad_numel(numel)
self._needs_sync = False
def _add_fp16_grads_to_fp32(self, mul=0.0):
......@@ -131,7 +134,7 @@ class _FP16OptimizerMixin(object):
offset : offset + numel
] += mul * p.grad.data.float().view(-1)
p.grad.zero_()
offset += numel
offset += pad_numel(numel)
self._needs_sync = False
def _sync_fp32_params_to_fp16(self):
......@@ -144,7 +147,7 @@ class _FP16OptimizerMixin(object):
utils.fp32_to_bf16_sr(u, p)
else:
p.data.copy_(u)
offset += numel
offset += pad_numel(numel)
def _unscale_grads(self):
self._sync_fp16_grads_to_fp32()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment