"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "0599a7b8410dc5cfdb477900b280475ae775d7f9"
Unverified commit 89ac1b7b authored by Guolin Ke, committed by GitHub

flatten parameters (#4)

flatten memories
parent 1e8b6e33
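
The core change in this commit is to flatten the model's FP16 parameters into one contiguous buffer per dtype (flatten_fp16_parameters) and to keep the FP32 master copy as a single flat tensor (build_fp32_params), so gradient syncing and the optimizer step touch a few large tensors instead of many small ones. As orientation before the diff, here is a minimal standalone sketch of the flattening idea; the name flatten_parameters and the details are illustrative, not the exact helper added below.

import torch

def flatten_parameters(params):
    # Illustrative sketch: pack all parameter tensors into one contiguous
    # buffer and re-point each parameter's .data at a view of that buffer.
    total = sum(p.data.numel() for p in params)
    flat = params[0].data.new_zeros(total)
    offset = 0
    for p in params:
        n = p.data.numel()
        flat[offset:offset + n].copy_(p.data.view(-1))
        p.data = flat[offset:offset + n].view(*p.shape)
        offset += n
    return flat

Once every parameter aliases a slice of one flat tensor, copying grads to the FP32 master and copying updated values back can be done with simple offset-based slices, which is what the rewritten mixin methods below do.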
@@ -13,6 +13,14 @@ from unicore import utils
 from .dynamic_loss_scaler import DynamicLossScaler
 
 
+def check_param_device(params):
+    if len(params) <= 0:
+        return True
+    device = params[0].device
+    for i in range(1, len(params)):
+        assert device == params[i].device
+
+
 class _FP16OptimizerMixin(object):
     def __init__(self, args, **kwargs):
         # forward __init__ call to the next class in mro(method resolution order)
@@ -23,25 +31,53 @@ class _FP16OptimizerMixin(object):
     @classmethod
     def build_fp32_params(cls, args, params):
         # create FP32 copy of parameters and grads
-        total_param_size = sum(p.data.numel() for p in params)
-        devices = [torch.cuda.current_device()]
-        fp32_params = {}
-        for device in devices:
-            device_param_size = total_param_size
-            device_params = params
-            fp32_params[device] = (
-                device_params[0].new(0).float().new(device_param_size)
-            )
-            offset = 0
-            for p in device_params:
-                numel = p.data.numel()
-                fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
-                offset += numel
-            fp32_params[device] = torch.nn.Parameter(fp32_params[device])
-            fp32_params[device].grad = fp32_params[device].data.new(
-                device_param_size
-            )
+        total_param_size = sum([p.data.numel() for p in params])
+        fp32_params = params[0].new(0).float().new(total_param_size)
+        offset = 0
+        for p in params:
+            numel = p.data.numel()
+            fp32_params[offset : offset + numel].copy_(p.data.view(-1))
+            offset += numel
+        fp32_params = torch.nn.Parameter(fp32_params)
+        fp32_params.grad = fp32_params.data.new(total_param_size)
         return fp32_params
 
+    @classmethod
+    def flatten_fp16_parameters(cls, args, params):
+        dtype_grouped_params = {}
+        for p in params:
+            if p.dtype not in dtype_grouped_params:
+                dtype_grouped_params[p.dtype] = []
+            dtype_grouped_params[p.dtype].append(p)
+        flatten_params = {}
+        for dtype in dtype_grouped_params:
+            cur_params = dtype_grouped_params[dtype]
+            total_param_size = sum(p.data.numel() for p in cur_params)
+            flatten_params[dtype] = (
+                cur_params[0].new(0).type(dtype).new(total_param_size)
+            )
+            offset = 0
+            for p in cur_params:
+                numel = p.data.numel()
+                flatten_params[dtype][offset : offset + numel].copy_(p.data.view(-1))
+                p.data = (
+                    flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
+                )
+                offset += numel
+            flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
+            flatten_params[dtype].grad = flatten_params[dtype].data.new(
+                total_param_size
+            )
+            offset = 0
+            for p in cur_params:
+                numel = p.data.numel()
+                p.grad = (
+                    flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
+                )
+                offset += numel
+        torch.cuda.empty_cache()
+        return list(flatten_params.values())
+
     def state_dict(self):
         """Return the optimizer's state dict."""
@@ -77,55 +113,33 @@ class _FP16OptimizerMixin(object):
     def _sync_fp16_grads_to_fp32(self):
         with torch.no_grad():
             if self._needs_sync:
-                devices = list(self.fp32_params.keys())
-                device_params_dict = defaultdict(list)
-                for p in self.fp16_params:
-                    if p.requires_grad:
-                        device_params_dict[p.device.index].append(p)
-                for device in devices:
-                    device_params = device_params_dict[device]
-                    offset = 0
-                    for p in device_params:
-                        numel = p.numel()
-                        if p.grad is not None:
-                            self.fp32_params[device].grad.data[
-                                offset : offset + numel
-                            ].copy_(p.grad.data.view(-1))
-                        offset += numel
+                offset = 0
+                for p in self.fp16_params:
+                    numel = p.numel()
+                    self.fp32_params.grad.data[offset : offset + numel].copy_(
+                        p.grad.data.view(-1)
+                    )
+                    offset += numel
                 self._needs_sync = False
 
     def _add_fp16_grads_to_fp32(self, mul=0.0):
         with torch.no_grad():
-            devices = list(self.fp32_params.keys())
-            device_params_dict = defaultdict(list)
-            for p in self.fp16_params:
-                if p.requires_grad:
-                    device_params_dict[p.device.index].append(p)
-            for device in devices:
-                device_params = device_params_dict[device]
-                offset = 0
-                for p in device_params:
-                    numel = p.numel()
-                    if p.grad is not None:
-                        self.fp32_params[device].grad.data[
-                            offset : offset + numel
-                        ] += mul * p.grad.data.float().view(-1)
-                        p.grad = None
-                    offset += numel
+            offset = 0
+            for p in self.fp16_params:
+                numel = p.numel()
+                self.fp32_params.grad.data[
+                    offset : offset + numel
+                ] += mul * p.grad.data.float().view(-1)
+                p.grad.zero_()
+                offset += numel
             self._needs_sync = False
 
     def _sync_fp32_params_to_fp16(self):
         # copy FP32 params back into FP16 model
-        devices = list(self.fp32_params.keys())
-        device_params_dict = defaultdict(list)
-        for p in self.fp16_params:
-            device_params_dict[p.device.index].append(p)
-        for device in devices:
-            device_params = device_params_dict[device]
-            offset = 0
-            for p in device_params:
-                numel = p.data.numel()
-                u = self.fp32_params[device].data[offset : offset + numel].view_as(p.data)
+        offset = 0
+        for p in self.fp16_params:
+            numel = p.numel()
+            u = self.fp32_params.data[offset : offset + numel].view_as(p.data)
             if self.bf16_sr and p.dtype == torch.bfloat16:
                 utils.fp32_to_bf16_sr(u, p)
             else:
@@ -159,7 +173,9 @@ class _FP16OptimizerMixin(object):
         """Clips gradient norm."""
         if max_norm <= 0.0:
             return 0.0
-        grad_norm = self._multiply_factor * utils.clip_grad_norm_(self.fp16_params, 0, aggregate_norm_fn)
+        grad_norm = self._multiply_factor * utils.clip_grad_norm_(
+            self.fp16_params, 0, aggregate_norm_fn
+        )
         # grad_norm = 1.0
         if grad_norm > max_norm > 0.0:
             clip_coef = max_norm / (grad_norm + 1e-6)
@@ -167,12 +183,12 @@ class _FP16OptimizerMixin(object):
             clip_coef = 1.0
         self._add_fp16_grads_to_fp32(mul=clip_coef)
 
     def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
         """Clips gradient norm and updates dynamic loss scaler."""
         self._sync_fp16_grads_to_fp32()
         grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
-            0, aggregate_norm_fn=aggregate_norm_fn,
+            0,
+            aggregate_norm_fn=aggregate_norm_fn,
         )
         if self.scaler is not None:
@@ -190,7 +206,9 @@ class _FP16OptimizerMixin(object):
         """Performs a single optimization step."""
         self._sync_fp16_grads_to_fp32()
         if getattr(self, "supports_step_with_scale", False):
-            self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
+            self.fp32_optimizer.step(
+                closure, scale=(1.0 / self._multiply_factor), groups=groups
+            )
         else:
             self._unscale_grads()
             self.fp32_optimizer.step(closure, groups=groups)
@@ -203,7 +221,7 @@ class _FP16OptimizerMixin(object):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
         for p in self.fp16_params:
-            p.grad = None
+            p.grad.zero_()
         if torch.is_tensor(self.fp32_params):
             self.fp32_params.grad.zero_()
         elif isinstance(self.fp32_params, dict):
@@ -237,12 +255,8 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
-            data_parallel_size = int(
-                args.distributed_world_size
-            )
-            scale_window = int(
-                2 ** 14 / data_parallel_size / args.update_freq[0]
-            )
+            data_parallel_size = int(args.distributed_world_size)
+            scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
         else:
             scale_window = args.fp16_scale_window
@@ -267,6 +281,8 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
         """
         flatten = not getattr(args, "fp16_no_flatten_grads", False)
         assert flatten
+        check_param_device(params)
+        params = cls.flatten_fp16_parameters(args, params)
         fp32_params = cls.build_fp32_params(args, params)
         fp32_optimizer = optim.build_optimizer(args, [fp32_params])
         return cls(args, params, fp32_optimizer, fp32_params, **kwargs)
@@ -297,7 +313,7 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
         if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"):
             self._sync_fp16_grads_to_fp32()
             with torch.no_grad():
-                params = [p for p in self.fp32_optimizer.params]
+                params = [self.fp32_params]
                 module.all_reduce_params(params)
         else:
             self.fp32_optimizer.all_reduce_grads(module)
......
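
Taken together, the mixin follows the standard flat-master-copy FP16 recipe: FP16 grads are copied (or accumulated) into the grad of one flat FP32 parameter, the wrapped FP32 optimizer steps on that single tensor, and the updated FP32 values are copied back into the FP16 views. A simplified sketch of that cycle, with hypothetical names and without loss scaling or clipping:

import torch

def fp16_update_step(fp16_params, fp32_master, fp32_optimizer):
    # 1) copy the FP16 grads into the flat FP32 master grad
    offset = 0
    for p in fp16_params:
        n = p.numel()
        fp32_master.grad[offset:offset + n].copy_(p.grad.view(-1).float())
        offset += n
    # 2) update the FP32 master copy with a regular optimizer
    fp32_optimizer.step()
    # 3) copy the updated FP32 values back into the FP16 parameters
    offset = 0
    for p in fp16_params:
        n = p.numel()
        p.data.copy_(fp32_master.data[offset:offset + n].view_as(p))
        offset += n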