Commit a92c6297 authored by Guolin Ke

fix bug in flatten parameters

parent 9954c186
@@ -21,6 +21,9 @@ def check_param_device(params):
         assert device == params[i].device
 
 
+def pad_numel(numel, multiplier=2):
+    return (numel + multiplier - 1) // multiplier * multiplier
+
 class _FP16OptimizerMixin(object):
     def __init__(self, args, **kwargs):
         # forward __init__ call to the next class in the MRO (method resolution order)
@@ -53,7 +56,7 @@ class _FP16OptimizerMixin(object):
         flatten_params = {}
         for dtype in dtype_grouped_params:
             cur_params = dtype_grouped_params[dtype]
-            total_param_size = sum(p.data.numel() for p in cur_params)
+            total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
             flatten_params[dtype] = (
                 cur_params[0].new(0).type(dtype).new(total_param_size)
             )
@@ -64,7 +67,7 @@ class _FP16OptimizerMixin(object):
                 p.data = (
                     flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
                 )
-                offset += numel
+                offset += pad_numel(numel)
             flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
             flatten_params[dtype].grad = flatten_params[dtype].data.new(
                 total_param_size
@@ -75,7 +78,7 @@ class _FP16OptimizerMixin(object):
                 p.grad = (
                     flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
                 )
-                offset += numel
+                offset += pad_numel(numel)
             torch.cuda.empty_cache()
         return list(flatten_params.values())
 
@@ -119,7 +122,7 @@ class _FP16OptimizerMixin(object):
             self.fp32_params.grad.data[offset : offset + numel].copy_(
                 p.grad.data.view(-1)
             )
-            offset += numel
+            offset += pad_numel(numel)
         self._needs_sync = False
 
     def _add_fp16_grads_to_fp32(self, mul=0.0):
@@ -131,7 +134,7 @@ class _FP16OptimizerMixin(object):
                 offset : offset + numel
             ] += mul * p.grad.data.float().view(-1)
             p.grad.zero_()
-            offset += numel
+            offset += pad_numel(numel)
         self._needs_sync = False
 
     def _sync_fp32_params_to_fp16(self):
@@ -144,7 +147,7 @@ class _FP16OptimizerMixin(object):
                 utils.fp32_to_bf16_sr(u, p)
             else:
                 p.data.copy_(u)
-            offset += numel
+            offset += pad_numel(numel)
 
     def _unscale_grads(self):
         self._sync_fp16_grads_to_fp32()
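The fix pads every parameter's slice of the flat buffer to a multiple of two elements: total_param_size is now computed from padded sizes, and every offset advances by pad_numel(numel) instead of numel, so the buffer sizing and the offset bookkeeping stay consistent. The commit message does not state the motivation, but rounding each slice up to an even element count keeps every parameter's view starting at an even element offset in the shared storage (e.g., with 2-byte fp16/bf16 elements, an even element offset is a 4-byte boundary). Below is a minimal standalone sketch of the padded-offset scheme; the tensor shapes and the plain torch.zeros buffer are illustrative assumptions standing in for the dtype-grouped flatten_params, not code from this repository:

import torch

def pad_numel(numel, multiplier=2):
    # Round numel up to the nearest multiple of `multiplier`:
    # pad_numel(5) == 6, pad_numel(8) == 8.
    return (numel + multiplier - 1) // multiplier * multiplier

# Illustrative parameter shapes (assumed, not from the commit).
params = [torch.randn(3), torch.randn(2, 2), torch.randn(5)]

# Size the flat buffer from *padded* sizes, mirroring the patched
# total_param_size computation.
total = sum(pad_numel(p.numel()) for p in params)
flat = torch.zeros(total)

offset = 0
for p in params:
    numel = p.numel()
    # Copy the parameter in, then re-point it at its slice of the buffer.
    flat[offset : offset + numel].copy_(p.view(-1))
    p.data = flat[offset : offset + numel].view(*p.shape)
    # Advance by the padded size, as in the patched hunks above; every
    # slice therefore starts at an even element offset (here 0, 4, 8).
    offset += pad_numel(numel)

assert offset == total  # padded sizing and padded offsets stay in sync

The same invariant is what the remaining hunks enforce: every loop that walks the flat fp32 buffer (grad sync, grad accumulation, param sync) must advance by the identical padded amount, otherwise reads and writes would drift out of step with the layout built when the parameters were flattened.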