Unverified Commit abe2204d authored by Jeff Rasley, committed by GitHub

Support fp32 grad clipping and fix max_grad_norm confusion (#232)

* updates to support fp32 gradient clipping and disallow max_grad_norm as an optimizer parameter
parent 6a9d57f6
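
For users, the net effect of this change is a config migration: the clipping threshold moves out of the optimizer's params (where max_grad_norm is now rejected) into the top-level gradient_clipping setting, which applies to both fp16 and fp32 training. A minimal before/after sketch, written as the Python config dicts used in the test suite (values are illustrative):

    # Old style: clipping threshold passed to the optimizer. After this change,
    # 'max_grad_norm' in the optimizer params raises a ValueError at initialization.
    old_config = {
        "train_batch_size": 2,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "max_grad_norm": 1.0
            }
        }
    }

    # New style: clipping is a top-level DeepSpeed setting.
    new_config = {
        "train_batch_size": 2,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0
    }
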
@@ -132,10 +132,6 @@ def get_dump_state(param_dict):
 def get_gradient_clipping(param_dict):
-    grad_clip = get_optimizer_gradient_clipping(param_dict)
-    if grad_clip is not None:
-        return grad_clip
-    else:
-        return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
+    return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
@@ -443,9 +439,6 @@ class DeepSpeedConfig(object):
     def _do_warning_check(self):
         fp16_enabled = self.fp16_enabled or self.zero_enabled
-        if self.gradient_clipping > 0. and not fp16_enabled:
-            logging.warning(
-                'DeepSpeedConfig: gradient clipping enabled without FP16 enabled.')
         vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
         if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
@@ -123,7 +123,7 @@ FP16_MIN_LOSS_SCALE_DEFAULT = 1
 # Gradient clipping. By default, this feature is not enabled.
 # Users can configure in ds_config.json as below example:
 GRADIENT_CLIPPING_FORMAT = '''
-Dump state should be enabled as:
+Gradient clipping should be enabled as:
 "gradient_clipping": 1.0
 '''
 GRADIENT_CLIPPING = 'gradient_clipping'
@@ -503,11 +503,10 @@ class DeepSpeedLight(Module):
     def _configure_basic_optimizer(self, model_parameters):
         optimizer_parameters = self.optimizer_params()
-        if self.fp16_enabled() and 'max_grad_norm' in optimizer_parameters.keys():
-            warnings.warn(
+        if 'max_grad_norm' in optimizer_parameters.keys():
+            raise ValueError(
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
-            optimizer_parameters['max_grad_norm'] = 0.0
         if self.optimizer_name() == ADAM_OPTIMIZER:
             from apex.optimizers.fused_adam import FusedAdam
             optimizer = FusedAdam(model_parameters, **optimizer_parameters)
@@ -778,6 +777,10 @@ class DeepSpeedLight(Module):
         for param_name, param in self.module.named_parameters():
             param.grad = None
 
+    def clip_fp32_gradients(self):
+        torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
+                                       max_norm=self.gradient_clipping())
+
     def step(self):
         r"""Execute the weight update step after forward and backward propagation on effective_train_batch
         """
@@ -790,6 +793,10 @@ class DeepSpeedLight(Module):
         report_progress = self.global_rank == 0 if self.global_rank else True
 
         if self.is_gradient_accumulation_boundary():
+            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
+                self.clip_fp32_gradients()
+
             self.optimizer.step()
 
             #zero grad in basic optimizer could be unreliable and may not exhibit
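
Taken together, the step() change means fp32 runs now clip at each gradient accumulation boundary, right before the optimizer update, using the same threshold that fp16 runs already honored. A minimal plain-PyTorch sketch of that ordering (the model, data, and accumulation count are illustrative, not DeepSpeed internals):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    gradient_clipping = 1.0
    gradient_accumulation_steps = 4

    for step, batch in enumerate(torch.randn(32, 4, 10)):
        loss = model(batch).pow(2).mean()
        loss.backward()  # gradients accumulate across micro-batches

        if (step + 1) % gradient_accumulation_steps == 0:  # accumulation boundary
            # Clip the accumulated fp32 gradients just before the update,
            # mirroring clip_fp32_gradients() in the engine's step().
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=gradient_clipping)
            optimizer.step()
            optimizer.zero_grad()
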
@@ -8,11 +8,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -11,11 +11,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -8,11 +8,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
    }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -8,11 +8,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -11,11 +11,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -8,11 +8,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "scheduler": {
     "type": "WarmupLR",
     "params": {
@@ -7,11 +7,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -9,11 +9,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -7,11 +7,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
@@ -4,11 +4,10 @@
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
+  "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
     "loss_scale": 0
@@ -136,10 +136,10 @@ def test_checkpoint_unfused_optimizer(tmpdir):
         "optimizer": {
             "type": "Lamb",
             "params": {
-                "lr": 0.00015,
-                "max_grad_norm": 1.0
+                "lr": 0.00015
             }
         },
+        "gradient_clipping": 1.0,
         "fp16": {
             "enabled": True
         },
@@ -8,6 +8,40 @@ from common import distributed_test
 from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
 
 
+def test_lamb_fp32_grad_clip(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Lamb",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1, 2])
+    def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0].float(), batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_lamb_fp32_grad_clip(args=args, model=model, hidden_dim=hidden_dim)
+
+
 def test_lamb_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 2,
@@ -15,10 +49,10 @@ def test_lamb_fp16_basic(tmpdir):
         "optimizer": {
             "type": "Lamb",
             "params": {
-                "lr": 0.00015,
-                "max_grad_norm": 1.0
+                "lr": 0.00015
             }
         },
+        "gradient_clipping": 1.0,
         "fp16": {
             "enabled": True
         }
@@ -52,10 +86,10 @@ def test_lamb_fp16_empty_grad(tmpdir):
         "optimizer": {
             "type": "Lamb",
             "params": {
-                "lr": 0.00015,
-                "max_grad_norm": 1.0
+                "lr": 0.00015
             }
         },
+        "gradient_clipping": 1.0,
         "fp16": {
             "enabled": True
         }