Unverified commit dd03cff2 authored by Jeff Rasley, committed by GitHub

set adamw_mode default true (follows FusedAdam and < 0.3.11 logic) (#844)

parent 564eb4bd
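
In practical terms (a hedged illustration using only the config keys that appear in the diff and test below; the values are made up), an optimizer of type "Adam" now defaults to AdamW-style decoupled weight decay unless the user opts out:

# Sketch of the user-facing effect: "Adam" now behaves like AdamW by default,
# matching FusedAdam and DeepSpeed releases before 0.3.11.
# To keep classic Adam (L2-regularization) behavior, opt out explicitly:
ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
            "adam_w_mode": False  # omit (or set True) to accept the new default
        }
    }
}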
@@ -74,7 +74,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
         self.opt_id = DeepSpeedCPUAdam.optimizer_id
         DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
+        self.adam_w_mode = adamw_mode
         self.ds_opt_adam = CPUAdamBuilder().load()
         self.ds_opt_adam.create_adam(self.opt_id,
...
@@ -40,6 +40,10 @@ DEEPSPEED_OPTIMIZERS = [
 # extra optimizer parameters for adam/adamw
 TORCH_ADAM_PARAM = "torch_adam"
 
+# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
+ADAM_W_MODE = "adam_w_mode"
+ADAM_W_MODE_DEFAULT = True
+
 class DeepSpeedConfigError(Exception):
     pass
...
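
A quick illustration of how these new keys are consumed (a sketch mirroring the engine change below, with a made-up params dict):

# Both keys are popped from the optimizer "params" section so they never reach
# the underlying optimizer constructor (TORCH_ADAM_PARAM == "torch_adam",
# ADAM_W_MODE == "adam_w_mode", ADAM_W_MODE_DEFAULT == True as defined above).
params = {"lr": 0.00015, "torch_adam": False, "adam_w_mode": False}
torch_adam = params.pop("torch_adam", False)   # -> False
adam_w_mode = params.pop("adam_w_mode", True)  # -> False (explicit opt-out)
# params is now {"lr": 0.00015} and is forwarded as **optimizer_parameters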
@@ -22,7 +22,7 @@ from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
-    TORCH_ADAM_PARAM
+    TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
...
@@ -640,26 +640,30 @@ class DeepSpeedEngine(Module):
         if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
             torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
-            adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER
-            # zero-offload  torch-adam  adam_w_mode  optimizer
-            #   T|F            T            T        torch.optim.AdamW
-            #   T|F            T            F        torch.optim.Adam
-            #   T              F           T|F       DeepSpeedCPUAdam(adam_w_mode)
-            #   F              F           T|F       FusedAdam(adam_w_mode)
+            adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)
+
+            # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
+            effective_adam_w_mode = self.optimizer_name(
+            ) == ADAMW_OPTIMIZER or adam_w_mode
+
             if torch_adam:
-                if adam_w_mode:
-                    optimizer = torch.optim.AdamW(model_parameters,
-                                                  **optimizer_parameters)
-                else:
+                if not effective_adam_w_mode:
                     optimizer = torch.optim.Adam(model_parameters,
                                                  **optimizer_parameters)
-            elif self.zero_cpu_offload():
-                optimizer = DeepSpeedCPUAdam(model_parameters,
-                                             **optimizer_parameters,
-                                             adamw_mode=adam_w_mode)
+                else:
+                    optimizer = torch.optim.AdamW(model_parameters,
+                                                  **optimizer_parameters)
             else:
-                optimizer_parameters['adam_w_mode'] = adam_w_mode
-                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+                if self.zero_cpu_offload():
+                    from deepspeed.ops.adam import DeepSpeedCPUAdam
+                    optimizer = DeepSpeedCPUAdam(model_parameters,
+                                                 **optimizer_parameters,
+                                                 adamw_mode=effective_adam_w_mode)
+                else:
+                    from deepspeed.ops.adam import FusedAdam
+                    optimizer = FusedAdam(model_parameters,
+                                          **optimizer_parameters,
+                                          adam_w_mode=effective_adam_w_mode)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             from deepspeed.ops.lamb import FusedLamb
...
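
The change drops the old inline decision table, so here is a hedged standalone restatement of the new selection logic (not the engine code itself; select_adam is a hypothetical helper written only for illustration):

def select_adam(is_adamw_name, torch_adam, zero_cpu_offload, adam_w_mode=True):
    """Return (optimizer class name, mode flag passed to it) for a given config.

    Mirrors the branch above: the AdamW optimizer name forces AdamW logic,
    otherwise adam_w_mode (now defaulting to True) decides.
    """
    effective = is_adamw_name or adam_w_mode
    if torch_adam:
        # torch optimizers encode the mode in the class itself, no flag is passed
        return ("torch.optim.AdamW" if effective else "torch.optim.Adam", None)
    if zero_cpu_offload:
        return ("DeepSpeedCPUAdam", effective)  # forwarded as adamw_mode=
    return ("FusedAdam", effective)             # forwarded as adam_w_mode=

# e.g. plain Adam, no offload, no explicit adam_w_mode -> FusedAdam in AdamW mode
assert select_adam(False, torch_adam=False, zero_cpu_offload=False) == ("FusedAdam", True)

The new unit test below enumerates exactly these combinations.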
import deepspeed
import torch
import pytest
from common import distributed_test
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from simple_model import SimpleModel, args_from_dict

# yapf: disable
# optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
                ["AdamW", False, True, False, (torch.optim.AdamW, None)],
                ["AdamW", True, False, False, (DeepSpeedCPUAdam, True)],
                ["AdamW", True, True, False, (torch.optim.AdamW, None)],
                ["AdamW", False, False, True, (FusedAdam, True)],
                ["AdamW", False, True, True, (torch.optim.AdamW, None)],
                ["AdamW", True, False, True, (DeepSpeedCPUAdam, True)],
                ["AdamW", True, True, True, (torch.optim.AdamW, None)],
                ["Adam", False, False, False, (FusedAdam, False)],
                ["Adam", False, True, False, (torch.optim.Adam, None)],
                ["Adam", True, False, False, (DeepSpeedCPUAdam, False)],
                ["Adam", True, True, False, (torch.optim.Adam, None)],
                ["Adam", False, False, True, (FusedAdam, True)],
                ["Adam", False, True, True, (torch.optim.AdamW, None)],
                ["Adam", True, False, True, (DeepSpeedCPUAdam, True)],
                ["Adam", True, True, True, (torch.optim.AdamW, None)]]


@pytest.mark.parametrize(
    'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer',
    adam_configs)
def test_adam_configs(tmpdir,
                      optimizer,
                      zero_offload,
                      torch_adam,
                      adam_w_mode,
                      resulting_optimizer):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": optimizer,
            "params": {
                "lr": 0.00015,
                "torch_adam": torch_adam,
                "adam_w_mode": adam_w_mode
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 2,
            "cpu_offload": zero_offload
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    @distributed_test(world_size=[1])
    def helper(args):
        model = SimpleModel(10)
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        # get base optimizer under zero
        ds_optimizer = model.optimizer.optimizer
        opt_class, adam_w_mode = resulting_optimizer
        assert isinstance(ds_optimizer, opt_class)
        if adam_w_mode in [True, False]:
            assert ds_optimizer.adam_w_mode == adam_w_mode

    helper(args)
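
Outside the test harness, the same check can be done by hand after deepspeed.initialize(); a minimal sketch (the helper name is ours, and the .optimizer.optimizer unwrapping is the same one the test above relies on under ZeRO):

def base_optimizer_of(engine):
    # Under ZeRO/FP16 the engine wraps the real optimizer; unwrap one level to
    # reach the torch.optim / FusedAdam / DeepSpeedCPUAdam instance and report
    # its adam_w_mode flag (None for the plain torch optimizers).
    base = engine.optimizer.optimizer
    print(type(base).__name__, getattr(base, "adam_w_mode", None))
    return base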