Unverified commit 865104be, authored by Olatunji Ruwase and committed by GitHub

Support optimizer AdamW type (#670)

parent f032e56f
@@ -27,17 +27,18 @@ from ..profiling.config import DeepSpeedFlopsProfilerConfig
 TENSOR_CORE_ALIGN_SIZE = 8
 ADAM_OPTIMIZER = 'adam'
+ADAMW_OPTIMIZER = 'adamw'
 LAMB_OPTIMIZER = 'lamb'
 ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
 DEEPSPEED_OPTIMIZERS = [
     ADAM_OPTIMIZER,
+    ADAMW_OPTIMIZER,
     LAMB_OPTIMIZER,
     ONEBIT_ADAM_OPTIMIZER,
 ]
-# extra optimizer parameters for adam
+# extra optimizer parameters for adam/adamw
 TORCH_ADAM_PARAM = "torch_adam"
-ADAM_W_MODE_PARAM = "adam_w_mode"
 class DeepSpeedConfigError(Exception):
...
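With ADAMW_OPTIMIZER registered above, a user config can request the new type by name. The snippet below is an illustrative sketch rather than part of this commit: it mirrors the config_dict pattern used in the new test_fp16_adam_types test at the bottom of this diff, and the batch size and learning rate are placeholder values.

# Illustrative DeepSpeed config dict (placeholder hyperparameters) that
# exercises the new 'AdamW' optimizer type; compare with the config_dict
# in test_fp16_adam_types below.
ds_config = {
    "train_batch_size": 1,
    "fp16": {
        "enabled": True
    },
    "optimizer": {
        "type": "AdamW",  # previously only "Adam" reached this code path
        "params": {
            "lr": 0.00015
        }
    }
}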
@@ -19,8 +19,8 @@ from deepspeed.runtime.activation_checkpointing import checkpointing as activati
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
-    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
-    TORCH_ADAM_PARAM, ADAM_W_MODE_PARAM
+    ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
+    TORCH_ADAM_PARAM
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
@@ -582,10 +582,9 @@ class DeepSpeedEngine(Module):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
-        if self.optimizer_name() == ADAM_OPTIMIZER:
+        if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
             torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
-            adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE_PARAM, True)
+            adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER
             # zero-offload  torch-adam  adam_w_mode  optimizer
             # T|F           T           T            torch.optim.AdamW
             # T|F           T           F            torch.optim.Adam
@@ -603,7 +602,7 @@ class DeepSpeedEngine(Module):
                                              **optimizer_parameters,
                                              adamw_mode=adam_w_mode)
             else:
-                optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode
+                optimizer_parameters['adam_w_mode'] = adam_w_mode
                 optimizer = FusedAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
...
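The truth-table comment above, together with the two branches visible in this hunk, determines which Adam implementation the engine builds. The function below is a minimal sketch of that decision, not the engine code itself: a boolean zero_cpu_offload argument stands in for however the engine detects ZeRO-Offload, and it returns a description string rather than constructing the real optimizers.

def pick_adam_variant(optimizer_name, optimizer_parameters, zero_cpu_offload):
    # Sketch of the zero-offload / torch-adam / adam_w_mode table:
    # adam_w_mode is now derived from the optimizer type instead of a
    # separate 'adam_w_mode' config entry.
    torch_adam = optimizer_parameters.pop("torch_adam", False)
    adam_w_mode = optimizer_name == "adamw"
    if torch_adam:
        return "torch.optim.AdamW" if adam_w_mode else "torch.optim.Adam"
    if zero_cpu_offload:
        return "DeepSpeedCPUAdam(adamw_mode={})".format(adam_w_mode)
    return "FusedAdam(adam_w_mode={})".format(adam_w_mode)

# e.g. pick_adam_variant("adamw", {"lr": 0.00015, "torch_adam": True}, False)
# -> 'torch.optim.AdamW'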
@@ -23,7 +23,12 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
     return my_group
-ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, FusedAdam, DeepSpeedCPUAdam]
+ZERO_SUPPORTED_OPTIMIZERS = [
+    torch.optim.Adam,
+    torch.optim.AdamW,
+    FusedAdam,
+    DeepSpeedCPUAdam
+]
 # Add apex FusedAdam to supported list if apex is installed
 try:
...
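Adding torch.optim.AdamW here lets a client-constructed AdamW optimizer pass ZeRO's supported-optimizer check. The sketch below assumes that check is a simple isinstance test with an allow_untested escape hatch, matching the behavior exercised by test_zero_allow_untested_optimizer later in this diff; it is not the engine's actual implementation.

import torch

# Stand-in for ZERO_SUPPORTED_OPTIMIZERS; the real list also carries
# FusedAdam / DeepSpeedCPUAdam (and apex's FusedAdam when installed).
SUPPORTED = (torch.optim.Adam, torch.optim.AdamW)

def check_zero_supported_optimizer(optimizer, allow_untested=False):
    # Unknown optimizers trip an assertion unless explicitly allowed, which
    # is why the untested-optimizer test expects an AssertionError.
    assert allow_untested or isinstance(optimizer, SUPPORTED), \
        "{} is not a ZeRO-supported optimizer".format(type(optimizer).__name__)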
@@ -35,7 +35,7 @@ def test_lamb_fp32_grad_clip(tmpdir):
     @distributed_test(world_size=[1, 2])
     def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -73,7 +73,7 @@ def test_lamb_fp16_basic(tmpdir):
     @distributed_test(world_size=[1, 2])
     def _test_lamb_fp16_basic(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -110,7 +110,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
     @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -147,7 +147,7 @@ def test_adam_fp32_empty_grad(tmpdir):
     @distributed_test(world_size=[2])
     def _test_adam_fp32_empty_grad(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -179,7 +179,7 @@ def test_adamw_fp16_basic(tmpdir):
     @distributed_test(world_size=[1])
     def _test_adamw_fp16_basic(args, model, hidden_dim):
         optimizer = torch.optim.AdamW(params=model.parameters())
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
         data_loader = random_dataloader(model=model,
@@ -210,7 +210,7 @@ def test_dict_config_adamw_fp16_basic():
     @distributed_test(world_size=[1])
     def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict):
         optimizer = torch.optim.AdamW(params=model.parameters())
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              config_params=config_dict)
@@ -245,7 +245,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
     @distributed_test(world_size=[1])
     def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
         optimizer = torch.optim.AdamW(params=model.parameters())
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
         data_loader = random_dataloader(model=model,
@@ -270,7 +270,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
                              True),
                          ])
 def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
-    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
+    # if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
     #    pytest.skip("cpu-adam is not installed")
     config_dict = {
         "train_batch_size": 1,
@@ -311,7 +311,7 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
     @distributed_test(world_size=[1])
     def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -338,7 +338,7 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload):
                              True),
                          ])
 def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
-    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
+    # if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
     #    pytest.skip("cpu-adam is not installed")
     config_dict = {
         "train_batch_size": 4,
@@ -364,7 +364,7 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
     def _test_zero_static_scale(args):
         hidden_dim = 10
         model = SimpleModel(hidden_dim, empty_grad=True)
-        model, optim, _,_ = deepspeed.initialize(args=args,
+        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
@@ -407,7 +407,7 @@ def test_zero_static_scale_deprecated_format(tmpdir):
     def _test_zero_static_scale(args):
         hidden_dim = 10
         model = SimpleModel(hidden_dim, empty_grad=True)
-        model, optim, _,_ = deepspeed.initialize(args=args,
+        model, optim, _, _ = deepspeed.initialize(args=args,
                                                  model=model,
                                                  model_parameters=model.parameters())
@@ -438,7 +438,7 @@ def test_zero_static_scale_deprecated_format(tmpdir):
                              True),
                          ])
 def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
-    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
+    # if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
     #    pytest.skip("cpu-adam is not installed")
     config_dict = {
         "train_batch_size": 4,
@@ -460,7 +460,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
         model = SimpleModel(hidden_dim, empty_grad=True)
         optimizer = SimpleOptimizer(model.parameters())
         with pytest.raises(AssertionError):
-            model, optim, _,_ = deepspeed.initialize(args=args,
+            model, optim, _, _ = deepspeed.initialize(args=args,
                                                      model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=model.parameters())
@@ -478,7 +478,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
                              True),
                          ])
 def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload):
-    #if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
+    # if use_cpu_offload and not deepspeed.ops.__installed_ops__['cpu-adam']:
     #    pytest.skip("cpu-adam is not installed")
     config_dict = {
         "train_micro_batch_size_per_gpu": 1,
@@ -536,7 +536,7 @@ def test_adam_amp_basic(tmpdir):
     @distributed_test(world_size=[1])
     def _test_adam_amp_basic(args, model, hidden_dim):
         optimizer = torch.optim.Adam(params=model.parameters())
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer)
         data_loader = random_dataloader(model=model,
@@ -574,7 +574,7 @@ def test_lamb_amp_basic(tmpdir):
     @distributed_test(world_size=[1, 2])
     def _test_lamb_amp_basic(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -613,7 +613,7 @@ def test_adam_amp_o2(tmpdir):
     @distributed_test(world_size=[1, 2])
     def _test_adam_amp_o2(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -652,7 +652,7 @@ def test_adam_amp_o2_empty_grad(tmpdir):
     @distributed_test(world_size=[2])
     def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -732,7 +732,7 @@ def test_zero2_reduce_scatter_off(tmpdir):
     @distributed_test(world_size=[2])
     def _helper(args, model, hidden_dim):
-        model, _, _,_ = deepspeed.initialize(args=args,
+        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
@@ -745,3 +745,53 @@ def test_zero2_reduce_scatter_off(tmpdir):
             model.step()
     _helper(args=args, model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('adam_type, torch_impl',
+                         [('Adam',
+                           True),
+                          ('Adam',
+                           False),
+                          ('AdamW',
+                           True),
+                          ('AdamW',
+                           False)])
+def test_fp16_adam_types(tmpdir, adam_type, torch_impl):
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 10
+        },
+        "optimizer": {
+            "type": adam_type,
+            "torch_adam": torch_impl,
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_fp16_adam_types(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=10,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for _, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_fp16_adam_types(args=args, model=model, hidden_dim=hidden_dim)