Unverified Commit a76572dc authored by Shaden Smith, committed by GitHub

Adding static loss scaling for ZeRO. (#166)

parent 012d91df
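
With this change, ZeRO can use a fixed (static) loss scale taken from the fp16 section of the DeepSpeed config, instead of always constructing a DynamicLossScaler. As a minimal sketch of the relevant config, mirroring the values used by the new unit test added below (to my understanding, a positive "loss_scale" selects the static scaler, while 0 keeps dynamic scaling):

    # Sketch only: this fp16 block mirrors the new test's config_dict.
    ds_config = {
        "train_batch_size": 4,
        "fp16": {
            "enabled": True,
            "loss_scale": 138.0   # fixed, positive value => static scaling
        },
        "zero_optimization": True
    }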
@@ -16,3 +16,5 @@ docs/code-docs/build
 .sass-cache/
 .jekyll-cache/
 .jekyll-metadata
+tests/unit/saved_checkpoint/
@@ -12,7 +12,7 @@ import torch.distributed as dist
 import math
 from torch._six import inf
-from deepspeed.pt.loss_scaler import DynamicLossScaler
+from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
@@ -175,15 +175,14 @@ class FP16_DeepSpeedZeroOptimizer(object):
         # we may have a way of fusing dynamic scale. Do not support for now
         if dynamic_loss_scale:
-            self.dynamic_loss_scale = True
             if dynamic_loss_args is None:
                 self.loss_scaler = DynamicLossScaler()
             else:
                 self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
+            self.dynamic_loss_scale = True
         else:
             self.dynamic_loss_scale = False
+            self.loss_scaler = LossScaler(scale=static_loss_scale)
         self.cur_iter = 0
         self.mpu = mpu
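
For background, fp16 loss scaling multiplies the loss by a constant before backward so that small gradients do not flush to zero in half precision, then divides the gradients by the same constant before the optimizer applies them. With a static scale the constant never changes, unlike DynamicLossScaler, which adjusts it when overflow is detected. A rough illustration of the idea only, not the LossScaler implementation itself; the helper name is hypothetical:

    import torch

    def static_scale_backward(loss, fp16_params, loss_scale=138.0):
        # Hypothetical helper illustrating a fixed (static) loss scale.
        # Scale the loss so small fp16 gradients survive the backward pass.
        (loss * loss_scale).backward()
        # Unscale the gradients before the optimizer consumes them.
        for p in fp16_params:
            if p.grad is not None:
                p.grad.div_(loss_scale)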
@@ -246,3 +246,46 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir):
     _test_adam_fp16_zero_onecycle_compatibility(args=args,
                                                 model=model,
                                                 hidden_dim=hidden_dim)
+
+
+def test_zero_static_scale(tmpdir):
+    config_dict = {
+        "train_batch_size": 4,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True,
+            "loss_scale": 138.
+        },
+        "zero_optimization": True
+    }
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=2)
+    def _test_zero_static_scale(args):
+        hidden_dim = 10
+        model = SimpleModel(hidden_dim, empty_grad=True)
+        model, optim, _, _ = deepspeed.initialize(args=args,
+                                                  model=model,
+                                                  model_parameters=model.parameters())
+
+        # Ensure the static scaler is configured.
+        assert optim.dynamic_loss_scale == False
+        assert optim.loss_scaler.loss_scale == 138.
+
+        # Now make sure things work..
+        data_loader = random_dataloader(model=model,
+                                        total_samples=10,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_zero_static_scale(args)