Unverified commit 0be026e3 authored by Olatunji Ruwase, committed by GitHub

Support dynamic loss scale args in fp16 optimizers (#212)

* Support dynamic loss scale args in fp16 optimizers

* Update names
parent b2c87edf
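The new arguments flow in from the fp16 section of the DeepSpeed config. As a point of reference, a minimal sketch of such a config (field names and values mirror the unit tests added below; "loss_scale": 0 selects dynamic scaling, and "initial_scale_power" p gives a starting scale of 2**p):

# Illustrative fp16 section of a DeepSpeed config exercising the new
# dynamic loss scale arguments (values mirror the unit tests added below;
# not an exhaustive list of fp16 options).
config_dict = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,            # 0 selects dynamic loss scaling
        "initial_scale_power": 8,   # starting scale of 2**8
        "loss_scale_window": 2,     # overflow-free iterations before an increase
        "min_loss_scale": 0.25      # floor when backing off after overflows
    }
}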
@@ -7,6 +7,7 @@ import torch
 import logging
 import json
 from deepspeed.pt.deepspeed_constants import *
+from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE

 TENSOR_CORE_ALIGN_SIZE = 8
 ADAM_OPTIMIZER = 'adam'
@@ -72,10 +73,10 @@ def get_dynamic_loss_scale_args(param_dict):
                                           FP16_MIN_LOSS_SCALE,
                                           FP16_MIN_LOSS_SCALE_DEFAULT)
     loss_scale_args = {
-        'init_scale': 2**init_scale,
-        'scale_window': scale_window,
-        'delayed_shift': delayed_shift,
-        'min_scale': min_loss_scale
+        INITIAL_LOSS_SCALE: 2**init_scale,
+        SCALE_WINDOW: scale_window,
+        DELAYED_SHIFT: delayed_shift,
+        MIN_LOSS_SCALE: min_loss_scale
     }

     return loss_scale_args
......
@@ -7,9 +7,11 @@ This file is adapted from FP16_Optimizer in NVIDIA/apex
 import torch
 import logging
-import math
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
+import math
+from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE


 class FP16_Optimizer(object):
@@ -63,14 +65,19 @@ class FP16_Optimizer(object):
         # we may have a way of fusing dynamic scale. Do not support for now
         if dynamic_loss_scale:
-            if dynamic_loss_args is not None:
-                logging.warning("Do not support dynamic loss scale args for now.")
             self.dynamic_loss_scale = True
-            self.cur_scale = initial_dynamic_scale
             self.cur_iter = 0
             self.last_overflow_iter = -1
             self.scale_factor = 2
+            if dynamic_loss_args is None:
+                self.cur_scale = initial_dynamic_scale
+                self.scale_window = 1000
+                self.min_loss_scale = 1
+            else:
+                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
+                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
+                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
         else:
             self.dynamic_loss_scale = False
             self.cur_iter = 0
@@ -126,8 +133,9 @@ class FP16_Optimizer(object):
         self.overflow = self.overflow_checker.check_using_norm(norm_groups)
         prev_scale = self.cur_scale
-        if self.overflow:
-            self._update_scale(self.overflow)
+        self._update_scale(self.overflow)
+
+        if self.overflow:
             if self.verbose:
                 print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                       "scale: {}, reducing to {}".format(prev_scale,
@@ -178,8 +186,9 @@ class FP16_Optimizer(object):
         self.overflow = self.overflow_checker.check_using_norm(norm_groups)
         prev_scale = self.cur_scale
-        if self.overflow:
-            self._update_scale(self.overflow)
+        self._update_scale(self.overflow)
+
+        if self.overflow:
             if self.verbose:
                 print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                       "scale: {}, reducing to {}".format(prev_scale,
@@ -235,15 +244,26 @@ class FP16_Optimizer(object):
     def _update_scale(self, skip):
         if self.dynamic_loss_scale:
+            prev_scale = self.cur_scale
             if skip:
-                if self.verbose:
-                    print("\nGrad overflow on iteration", self.cur_iter)
-                    print("Using dynamic loss scale of", self.cur_scale)
-                self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+                self.cur_scale = max(self.cur_scale / self.scale_factor,
+                                     self.min_loss_scale)
                 self.last_overflow_iter = self.cur_iter
+                if self.verbose:
+                    print(f"\nGrad overflow on iteration {self.cur_iter}")
+                    print(
+                        f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
+                    )
             else:
-                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
+                # Ensure self.scale_window updates since last overflow
+                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
+                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                     self.cur_scale *= self.scale_factor
+                    if self.verbose:
+                        print(f"\nNo Grad overflow for {self.scale_window} iterations")
+                        print(
+                            f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
+                        )
         else:
             if skip:
                 print("\nGrad overflow on iteration", self.cur_iter)
......
@@ -7,10 +7,12 @@ This file is adapted from FP16_Optimizer in NVIDIA/apex
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
 import math
 import logging
+from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
+from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE


 class FP16_UnfusedOptimizer(object):
     """
@@ -62,14 +64,18 @@ class FP16_UnfusedOptimizer(object):
         # we may have a way of fusing dynamic scale. Do not support for now
         if dynamic_loss_scale:
-            if dynamic_loss_args is not None:
-                raise SystemError("Do not support dynamic loss scale args for now.")
             self.dynamic_loss_scale = True
-            self.cur_scale = 1.0 * 2**16
             self.cur_iter = 0
             self.last_overflow_iter = -1
             self.scale_factor = 2.0
+            if dynamic_loss_args is None:
+                self.cur_scale = 1.0 * 2**16
+                self.scale_window = 1000
+                self.min_loss_scale = 0.25
+            else:
+                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
+                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
+                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
         else:
             self.dynamic_loss_scale = False
             self.cur_iter = 0
@@ -128,8 +134,8 @@ class FP16_UnfusedOptimizer(object):
         self.overflow = self.overflow_checker.check_using_norm(norm_groups)
         prev_scale = self.cur_scale
-        if self.overflow:
-            self._update_scale(self.overflow)
+        self._update_scale(self.overflow)
+        if self.overflow:
             if self.verbose:
                 print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                       "scale: {}, reducing to {}".format(prev_scale,
@@ -153,8 +159,8 @@ class FP16_UnfusedOptimizer(object):
         self.overflow = self.overflow_checker.check()
         prev_scale = self.cur_scale
-        if self.overflow:
-            self._update_scale(self.overflow)
+        self._update_scale(self.overflow)
+        if self.overflow:
             if self.verbose:
                 print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
                       "scale: {}, reducing to {}".format(prev_scale,
@@ -224,14 +230,26 @@ class FP16_UnfusedOptimizer(object):
     def _update_scale(self, skip):
         if self.dynamic_loss_scale:
+            prev_scale = self.cur_scale
             if skip:
-                print("\nGrad overflow on iteration", self.cur_iter)
-                print("Using dynamic loss scale of", self.cur_scale)
-                self.cur_scale = max(self.cur_scale / self.scale_factor, 0.25)
+                self.cur_scale = max(self.cur_scale / self.scale_factor,
+                                     self.min_loss_scale)
                 self.last_overflow_iter = self.cur_iter
+                if self.verbose:
+                    print("\nGrad overflow on iteration", self.cur_iter)
+                    print(
+                        f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
+                    )
             else:
-                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
+                # Ensure self.scale_window updates since last overflow
+                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
+                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                     self.cur_scale *= self.scale_factor
+                    if self.verbose:
+                        print(f"\nNo Grad overflow for {self.scale_window} iterations")
+                        print(
+                            f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
+                        )
         else:
             if skip:
                 print("\nGrad overflow on iteration", self.cur_iter)
......
@@ -18,6 +18,11 @@
 import torch

+INITIAL_LOSS_SCALE = 'init_scale'
+SCALE_WINDOW = 'scale_window'
+DELAYED_SHIFT = 'delayed_shift'
+MIN_LOSS_SCALE = 'min_scale'


 # item() is a recent addition, so this helps with backward compatibility.
 def to_python_float(t):
......
import torch
import deepspeed
import argparse
import pytest
import json
import os
import numpy as np
from common import distributed_test
from simple_model import SimpleModel, args_from_dict
def run_model_step(model, gradient_list):
    for value in gradient_list:
        for p in model.parameters():
            p.grad = torch.empty_like(p, dtype=p.dtype)
            p.grad.fill_(value)
        model.step()
def test_fused_no_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_no_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)):
            run_model_step(model, [value])
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)
            if optim.cur_iter % expected_scale_window == 0:
                expected_loss_scale *= 2

    _test_fused_no_overflow(args)
def test_fused_all_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 4,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_all_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**4
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale

        overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6
        for i, value in enumerate(overflow_gradients):
            run_model_step(model, [value])
            expected_loss_scale = max(expected_loss_scale / 2, 1)
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)

    _test_fused_all_overflow(args)
def test_fused_some_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_fused_some_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        expected_iteration = 0
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf'), float('nan')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model scale_window + 1 times to increase scale once
        normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1)
        expected_iteration += len(normal_gradients)
        run_model_step(model, normal_gradients)
        expected_loss_scale *= 2
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

    _test_fused_some_overflow(args)
def test_unfused_no_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_no_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        for i, value in enumerate(np.random.uniform(-0.1, 0.1, 10)):
            run_model_step(model, [value])
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)
            if optim.cur_iter % expected_scale_window == 0:
                expected_loss_scale *= 2

    _test_unfused_no_overflow(args)
def test_unfused_all_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 4,
            "loss_scale_window": 2,
            "min_loss_scale": 0.25
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_all_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**4
        expected_min_loss_scale = 0.25
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.min_loss_scale == expected_min_loss_scale

        overflow_gradients = [float('inf'), float('-inf')] + [float('nan')] * 6
        for i, value in enumerate(overflow_gradients):
            run_model_step(model, [value])
            expected_loss_scale = max(expected_loss_scale / 2, expected_min_loss_scale)
            assert optim.cur_scale == expected_loss_scale
            assert optim.cur_iter == (i + 1)

    _test_unfused_all_overflow(args)
def test_unfused_some_overflow(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 8,
            "loss_scale_window": 2
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=1)
    def _test_unfused_some_overflow(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim, empty_grad=True)
        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
        expected_loss_scale = 2**8
        expected_scale_window = 2
        expected_iteration = 0
        # Ensure the dynamic loss scaler is correctly configured.
        assert optim.dynamic_loss_scale == True
        assert optim.cur_scale == expected_loss_scale
        assert optim.scale_window == expected_scale_window

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf'), float('nan')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model scale_window + 1 times to increase scale once
        normal_gradients = np.random.uniform(-0.1, 0.1, expected_scale_window + 1)
        expected_iteration += len(normal_gradients)
        run_model_step(model, normal_gradients)
        expected_loss_scale *= 2
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

        # Run model with overflows to decrease scale
        overflow_gradients = [float('inf')]
        expected_iteration += len(overflow_gradients)
        run_model_step(model, overflow_gradients)
        expected_loss_scale /= (2**len(overflow_gradients))
        assert optim.cur_scale == expected_loss_scale
        assert optim.cur_iter == expected_iteration

    _test_unfused_some_overflow(args)