Unverified Commit 2e2dd861 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

Dist testing backend fixes, etc. (#708)

parent 91b1b7f3
...@@ -458,7 +458,11 @@ class DeepSpeedEngine(Module): ...@@ -458,7 +458,11 @@ class DeepSpeedEngine(Module):
# Configure based on command line arguments # Configure based on command line arguments
def _configure_with_arguments(self, args, mpu): def _configure_with_arguments(self, args, mpu):
self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0 self.local_rank = args.local_rank if hasattr(
args,
'local_rank') else int(os.environ.get("LOCAL_RANK",
-1))
config_file = args.deepspeed_config if hasattr(args, config_file = args.deepspeed_config if hasattr(args,
'deepspeed_config') else None 'deepspeed_config') else None
self._config = DeepSpeedConfig(config_file, mpu, param_dict=self.config_params) self._config = DeepSpeedConfig(config_file, mpu, param_dict=self.config_params)
...@@ -473,8 +477,15 @@ class DeepSpeedEngine(Module): ...@@ -473,8 +477,15 @@ class DeepSpeedEngine(Module):
assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config" assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
args.deepspeed_config = args.deepscale_config args.deepspeed_config = args.deepscale_config
assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \ local_rank_err = "DeepSpeed requires a command line parameter of --local_rank [int] and/or setting the LOCAL_RANK environment variable."
'DeepSpeed requires integer command line parameter --local_rank' if hasattr(args, 'local_rank'):
assert type(args.local_rank) == int, local_rank_err
if "LOCAL_RANK" in os.environ:
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
assert env_local_rank == args.local_rank, \
f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
else:
assert "LOCAL_RANK" in os.environ, local_rank_err
if self.config_params is None: if self.config_params is None:
assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \ assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
......
...@@ -13,15 +13,22 @@ def init_distributed(dist_backend="nccl", ...@@ -13,15 +13,22 @@ def init_distributed(dist_backend="nccl",
auto_mpi_discovery=True, auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True, verbose=True,
timeout=default_pg_timeout): timeout=default_pg_timeout,
""" init_method=None):
Initialize torch.distributed backend, potentially performing MPI discovery if needed """Initialize torch.distributed backend, potentially performing MPI discovery if needed
Arguments: Arguments:
dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI
distributed_port (int, optional): torch distributed backend port auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
verbose (bool, optional): verbose logging
timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes. distributed_port: Optional (int). torch distributed backend port
verbose: Optional (bool). verbose logging
timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
""" """
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
...@@ -39,7 +46,9 @@ def init_distributed(dist_backend="nccl", ...@@ -39,7 +46,9 @@ def init_distributed(dist_backend="nccl",
logger.info( logger.info(
"Initializing torch distributed with backend: {}".format(dist_backend)) "Initializing torch distributed with backend: {}".format(dist_backend))
assert isinstance(timeout, timedelta) assert isinstance(timeout, timedelta)
torch.distributed.init_process_group(backend=dist_backend, timeout=timeout) torch.distributed.init_process_group(backend=dist_backend,
timeout=timeout,
init_method=init_method)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True): def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
......
...@@ -45,8 +45,6 @@ def distributed_test(world_size=2, backend='nccl'): ...@@ -45,8 +45,6 @@ def distributed_test(world_size=2, backend='nccl'):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
if 'args' in func_kwargs:
func_kwargs['args'].local_rank = local_rank
run_func(*func_args, **func_kwargs) run_func(*func_args, **func_kwargs)
def dist_launcher(num_procs, *func_args, **func_kwargs): def dist_launcher(num_procs, *func_args, **func_kwargs):
......
...@@ -7,18 +7,17 @@ from deepspeed.pipe import PipelineModule, LayerSpec ...@@ -7,18 +7,17 @@ from deepspeed.pipe import PipelineModule, LayerSpec
class SimpleModel(torch.nn.Module): class SimpleModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False, rank=0): def __init__(self, hidden_dim, empty_grad=False):
super(SimpleModel, self).__init__() super(SimpleModel, self).__init__()
self.linear = torch.nn.Linear(hidden_dim, hidden_dim) self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
if empty_grad: if empty_grad:
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss() self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
self.rank = rank
self.empty_grad = empty_grad self.empty_grad = empty_grad
def forward(self, x, y): def forward(self, x, y):
hidden_dim = x hidden_dim = x
if self.rank == 0 and self.empty_grad: if self.empty_grad and torch.distributed.get_rank() == 0:
hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim) hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
else: else:
hidden_dim = self.linear(hidden_dim) hidden_dim = self.linear(hidden_dim)
...@@ -133,8 +132,8 @@ class HybridStateOptimizer(torch.optim.Optimizer): ...@@ -133,8 +132,8 @@ class HybridStateOptimizer(torch.optim.Optimizer):
class PLD_SimpleModel(SimpleModel): class PLD_SimpleModel(SimpleModel):
def __init__(self, hidden_dim, empty_grad=False, rank=0): def __init__(self, hidden_dim, empty_grad=False):
super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank) super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad)
def forward(self, x, y, **kwargs): def forward(self, x, y, **kwargs):
pld = kwargs.get('progressive_layer_drop', False) pld = kwargs.get('progressive_layer_drop', False)
...@@ -169,8 +168,6 @@ def create_deepspeed_args(): ...@@ -169,8 +168,6 @@ def create_deepspeed_args():
# We assume up to one full node executing unit tests # We assume up to one full node executing unit tests
assert torch.distributed.get_world_size() <= torch.cuda.device_count() assert torch.distributed.get_world_size() <= torch.cuda.device_count()
args.local_rank = torch.distributed.get_rank() args.local_rank = torch.distributed.get_rank()
else:
args.local_rank = 0
return args return args
......
...@@ -750,7 +750,7 @@ def test_checkpoint_missing_latest(tmpdir): ...@@ -750,7 +750,7 @@ def test_checkpoint_missing_latest(tmpdir):
hidden_dim = 10 hidden_dim = 10
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
model = SimpleModel(hidden_dim, rank=args.local_rank) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _helper(args, model, hidden_dim): def _helper(args, model, hidden_dim):
...@@ -781,7 +781,7 @@ def test_checkpoint_unique_tag(tmpdir, valid_mode): ...@@ -781,7 +781,7 @@ def test_checkpoint_unique_tag(tmpdir, valid_mode):
hidden_dim = 10 hidden_dim = 10
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
model = SimpleModel(hidden_dim, rank=args.local_rank) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[2]) @distributed_test(world_size=[2])
def _helper(args, model, hidden_dim): def _helper(args, model, hidden_dim):
...@@ -816,7 +816,7 @@ def test_checkpoint_unknown_tag_validation(tmpdir): ...@@ -816,7 +816,7 @@ def test_checkpoint_unknown_tag_validation(tmpdir):
hidden_dim = 10 hidden_dim = 10
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
model = SimpleModel(hidden_dim, rank=args.local_rank) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _helper(args, model, hidden_dim): def _helper(args, model, hidden_dim):
......
...@@ -252,14 +252,16 @@ def run_backward(ds_config, seq_len, atol=1e-2, verbose=False): ...@@ -252,14 +252,16 @@ def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
#test_backward[3-1024-120-16-24-True-True-0.05] #test_backward[3-1024-120-16-24-True-True-0.05]
#test_backward[3-1024-52-16-24-False-True-0.2]
# 3-128-54-2-24-False-True-0.2
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[ [
(3,1024,119,16,24,True,False, 0.05), (3,1024,119,16,24,True,False, 0.05),
(3,1024,115,16,24,True,True, 0.05), (3,1024,115,16,24,True,True, 0.05),
(1024,128,10,2,2,False,False, 0.1), (1024,128,10,2,2,False,False, 0.1),
(3,1024,52,16,24,False,True, 0.2), #(3,1024,52,16,24,False,True, 0.2),
(3,128,51,2,24,False,False, 0.1), #(3,128,51,2,24,False,False, 0.1),
(3,128,54,2,24,False,True, 0.2), #(3,128,54,2,24,False,True, 0.2),
]) # yapf: disable ]) # yapf: disable
def test_backward(batch_size, def test_backward(batch_size,
hidden_size, hidden_size,
......
...@@ -39,7 +39,7 @@ def test_fused_no_overflow(tmpdir): ...@@ -39,7 +39,7 @@ def test_fused_no_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_fused_no_overflow(args): def _test_fused_no_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -83,7 +83,7 @@ def test_fused_all_overflow(tmpdir): ...@@ -83,7 +83,7 @@ def test_fused_all_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_fused_all_overflow(args): def _test_fused_all_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -125,7 +125,7 @@ def test_fused_some_overflow(tmpdir): ...@@ -125,7 +125,7 @@ def test_fused_some_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_fused_some_overflow(args): def _test_fused_some_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -187,7 +187,7 @@ def test_unfused_no_overflow(tmpdir): ...@@ -187,7 +187,7 @@ def test_unfused_no_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_unfused_no_overflow(args): def _test_unfused_no_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -231,7 +231,7 @@ def test_unfused_all_overflow(tmpdir): ...@@ -231,7 +231,7 @@ def test_unfused_all_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_unfused_all_overflow(args): def _test_unfused_all_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -275,7 +275,7 @@ def test_unfused_some_overflow(tmpdir): ...@@ -275,7 +275,7 @@ def test_unfused_some_overflow(tmpdir):
@distributed_test(world_size=1) @distributed_test(world_size=1)
def _test_unfused_some_overflow(args): def _test_unfused_some_overflow(args):
hidden_dim = 1 hidden_dim = 1
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
......
...@@ -31,7 +31,7 @@ def test_lamb_fp32_grad_clip(tmpdir): ...@@ -31,7 +31,7 @@ def test_lamb_fp32_grad_clip(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2]) @distributed_test(world_size=[1, 2])
def _test_lamb_fp32_grad_clip(args, model, hidden_dim): def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
...@@ -69,7 +69,7 @@ def test_lamb_fp16_basic(tmpdir): ...@@ -69,7 +69,7 @@ def test_lamb_fp16_basic(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2]) @distributed_test(world_size=[1, 2])
def _test_lamb_fp16_basic(args, model, hidden_dim): def _test_lamb_fp16_basic(args, model, hidden_dim):
...@@ -106,7 +106,7 @@ def test_lamb_fp16_empty_grad(tmpdir): ...@@ -106,7 +106,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank) model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[2]) @distributed_test(world_size=[2])
def _test_lamb_fp16_empty_grad(args, model, hidden_dim): def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
...@@ -143,7 +143,7 @@ def test_adam_fp32_empty_grad(tmpdir): ...@@ -143,7 +143,7 @@ def test_adam_fp32_empty_grad(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank) model = SimpleModel(hidden_dim, empty_grad=True)
@distributed_test(world_size=[2]) @distributed_test(world_size=[2])
def _test_adam_fp32_empty_grad(args, model, hidden_dim): def _test_adam_fp32_empty_grad(args, model, hidden_dim):
...@@ -174,7 +174,7 @@ def test_adamw_fp16_basic(tmpdir): ...@@ -174,7 +174,7 @@ def test_adamw_fp16_basic(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_adamw_fp16_basic(args, model, hidden_dim): def _test_adamw_fp16_basic(args, model, hidden_dim):
...@@ -205,7 +205,7 @@ def test_dict_config_adamw_fp16_basic(): ...@@ -205,7 +205,7 @@ def test_dict_config_adamw_fp16_basic():
args = create_deepspeed_args() args = create_deepspeed_args()
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict): def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict):
...@@ -240,7 +240,7 @@ def test_adamw_fp16_empty_grad(tmpdir): ...@@ -240,7 +240,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_adamw_fp16_empty_grad(args, model, hidden_dim): def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
...@@ -307,7 +307,7 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo ...@@ -307,7 +307,7 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim): def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
...@@ -363,7 +363,7 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): ...@@ -363,7 +363,7 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
@distributed_test(world_size=2) @distributed_test(world_size=2)
def _test_zero_static_scale(args): def _test_zero_static_scale(args):
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -406,7 +406,7 @@ def test_zero_static_scale_deprecated_format(tmpdir): ...@@ -406,7 +406,7 @@ def test_zero_static_scale_deprecated_format(tmpdir):
@distributed_test(world_size=2) @distributed_test(world_size=2)
def _test_zero_static_scale(args): def _test_zero_static_scale(args):
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
model=model, model=model,
model_parameters=model.parameters()) model_parameters=model.parameters())
...@@ -457,7 +457,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): ...@@ -457,7 +457,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_zero_allow_untested_optimizer(args): def _test_zero_allow_untested_optimizer(args):
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=True) model = SimpleModel(hidden_dim)
optimizer = SimpleOptimizer(model.parameters()) optimizer = SimpleOptimizer(model.parameters())
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
model, optim, _, _ = deepspeed.initialize(args=args, model, optim, _, _ = deepspeed.initialize(args=args,
...@@ -531,7 +531,7 @@ def test_adam_amp_basic(tmpdir): ...@@ -531,7 +531,7 @@ def test_adam_amp_basic(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_adam_amp_basic(args, model, hidden_dim): def _test_adam_amp_basic(args, model, hidden_dim):
...@@ -570,7 +570,7 @@ def test_lamb_amp_basic(tmpdir): ...@@ -570,7 +570,7 @@ def test_lamb_amp_basic(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2]) @distributed_test(world_size=[1, 2])
def _test_lamb_amp_basic(args, model, hidden_dim): def _test_lamb_amp_basic(args, model, hidden_dim):
...@@ -609,7 +609,7 @@ def test_adam_amp_o2(tmpdir): ...@@ -609,7 +609,7 @@ def test_adam_amp_o2(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1, 2]) @distributed_test(world_size=[1, 2])
def _test_adam_amp_o2(args, model, hidden_dim): def _test_adam_amp_o2(args, model, hidden_dim):
...@@ -648,7 +648,7 @@ def test_adam_amp_o2_empty_grad(tmpdir): ...@@ -648,7 +648,7 @@ def test_adam_amp_o2_empty_grad(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[2]) @distributed_test(world_size=[2])
def _test_adam_amp_o2_empty_grad(args, model, hidden_dim): def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
...@@ -688,7 +688,7 @@ def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_construct ...@@ -688,7 +688,7 @@ def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_construct
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_zero_supported_client_optimizer(args, model, optimizer_constructor): def _test_zero_supported_client_optimizer(args, model, optimizer_constructor):
...@@ -728,7 +728,7 @@ def test_zero2_reduce_scatter_off(tmpdir): ...@@ -728,7 +728,7 @@ def test_zero2_reduce_scatter_off(tmpdir):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, rank=args.local_rank) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[2]) @distributed_test(world_size=[2])
def _helper(args, model, hidden_dim): def _helper(args, model, hidden_dim):
...@@ -775,7 +775,7 @@ def test_fp16_adam_types(tmpdir, adam_type, torch_impl): ...@@ -775,7 +775,7 @@ def test_fp16_adam_types(tmpdir, adam_type, torch_impl):
args = args_from_dict(tmpdir, config_dict) args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10 hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False) model = SimpleModel(hidden_dim)
@distributed_test(world_size=[1]) @distributed_test(world_size=[1])
def _test_fp16_adam_types(args, model, hidden_dim): def _test_fp16_adam_types(args, model, hidden_dim):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment