Unverified Commit 8eaa3622 authored by anj-s, committed by GitHub

[refactor] Fix for using synthetic data + remove unused flags (#485)



* small fix, remove unused flags

* remove unused flag

* add back max_batch flag

* adding back lazy_construction

* adding back lazy_construction

* add missing device arg
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
parent 84cec202
@@ -23,9 +23,8 @@ from torch.optim import Adam
 from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
+from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim.oss import OSS
 from fairscale.utils.testing import dist_init, get_worker_map

 MPI_PORT = 29500
@@ -48,10 +47,7 @@ def get_model_and_optimizer(args, device, benchmark_config, model_config):
     lr = benchmark_config["lr"]

     def make_adam(params):
-        if args.ddp_zero:
-            return OSS(params=params, optim=Adam, group=get_data_parallel_group(), lr=lr)
-        else:
-            return Adam(params, lr=lr)
+        return Adam(params, lr=lr)

     optimizer = make_adam
     return model, optimizer
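With `--ddp-zero` removed, `make_adam` no longer branches into fairscale's `OSS` sharded optimizer and always builds a plain `torch.optim.Adam`. A minimal runnable sketch of the simplified factory, with a hypothetical stand-in model:

```python
import torch
from torch.optim import Adam

lr = 0.001  # stands in for benchmark_config["lr"]
model = torch.nn.Linear(16, 16)  # hypothetical; the benchmark builds a transformer LM

def make_adam(params):
    # Single code path after the refactor: no OSS/ZeRO-sharded branch.
    return Adam(params, lr=lr)

optimizer = make_adam(model.parameters())
```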
@@ -156,17 +152,10 @@ def train(model_config, model, benchmark_config, model_specs, args):
     pipe_group = model.group if hasattr(model, "group") else None

-    if args.ddp_zero:
-        model = DDP(
-            model,
-            device_ids=[torch.cuda.current_device()],
-            process_group=get_data_parallel_group(),
-            find_unused_parameters=False,
-        )
-
     # TODO(anj-s): Avoid sending fake data to all replicas except the first and last one.
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
-        lm_dataloader, _, _ = get_synthetic_dataloaders(args, benchmark_config, model_specs)
+        lm_dataloader, _, _ = get_synthetic_dataloaders(args, device, benchmark_config, model_specs)

     total_tokens = 0
     total_tokens_per_log_interval = 0
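This hunk is the titular fix: `get_synthetic_dataloaders` gained a `device` parameter, and this call site previously passed arguments that no longer matched the signature. The device-selection idiom added above is stock PyTorch:

```python
import torch

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Synthetic batches can then be fabricated directly on that device
# (shapes here are hypothetical, not the benchmark's real config).
fake_source = torch.randint(0, 1000, (8, 32), device=device)
print(fake_source.device)
```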
@@ -193,7 +182,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
         optimizer.zero_grad()
         try:
-            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
+            if pipe_group is None or pipe_group.rank() == 0:
                 tmp = source.to(get_device(model, 0))
                 output = model(tmp)
             else:
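With the `args.ddp_zero` clause dropped, the guard reduces to "only the first pipeline stage feeds real input". A runnable sketch of that gating pattern, using a hypothetical stand-in for the process group:

```python
import torch

class FakeGroup:
    """Hypothetical stand-in for a torch.distributed process group."""
    def rank(self) -> int:
        return 0
    def size(self) -> int:
        return 4

pipe_group = FakeGroup()
model = torch.nn.Identity()  # stand-in for the pipelined model
source = torch.randn(8, 32)

# Only rank 0 of the pipeline group pushes the batch into the model;
# later stages receive activations over the pipe instead of raw data.
if pipe_group is None or pipe_group.rank() == 0:
    output = model(source)
```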
@@ -206,17 +195,10 @@ def train(model_config, model, benchmark_config, model_specs, args):
                 output = output.to(target.device)

                 loss = criterion(output.view(-1, vocab_size), target.view(-1))
-                if args.ddp_zero:
-                    ddp_group = get_data_parallel_group()
-                    torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
-                    loss /= ddp_group.size()
                 loss.backward()
                 del target
             else:
-                if args.ddp_zero:
-                    model.module.back_helper(output)
-                else:
-                    model.back_helper(output)
+                model.back_helper(output)
             del output
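The deleted branch was the standard data-parallel loss averaging: all-reduce with SUM across the group, then divide by the group size. A self-contained sketch of that pattern using a single-process gloo group (world size 1, purely illustrative):

```python
import os
import torch
import torch.distributed as dist

# Single-process "world" so the averaging pattern runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

loss = torch.tensor(2.0)
dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # sum of per-replica losses
loss /= dist.get_world_size()                # mean loss per replica
print(loss)  # tensor(2.) when world_size == 1

dist.destroy_process_group()
```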
@@ -357,7 +339,7 @@ def generate_balance(num_devices, num_layers):
     return balance


-def get_synthetic_dataloaders(args, benchmark_config, model_specs):
+def get_synthetic_dataloaders(args, device, benchmark_config, model_specs):
     """Returns dataloader for synthetic data."""

     if args.model_name == "lm":
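For reference, a hedged sketch of what a synthetic LM dataloader honoring the new `device` argument could look like; the helper name and shapes are hypothetical, not the benchmark's actual implementation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_synthetic_lm_dataloader(device, batch_size=8, seq_len=32, vocab_size=1000, num_batches=4):
    """Fabricate (source, target) token batches directly on `device`."""
    source = torch.randint(0, vocab_size, (num_batches * batch_size, seq_len), device=device)
    target = torch.roll(source, shifts=-1, dims=1)  # next-token targets
    return DataLoader(TensorDataset(source, target), batch_size=batch_size)

loader = make_synthetic_lm_dataloader(torch.device("cpu"))
source, target = next(iter(loader))
print(source.shape, target.shape)  # torch.Size([8, 32]) torch.Size([8, 32])
```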
@@ -520,15 +502,13 @@ parser.add_argument("--multiprocess", action="store_true", help="Runs single pro
 parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
 parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
 parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
-parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
-parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
-parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
 parser.add_argument(
-    "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
+    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
 )
 parser.add_argument(
-    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
+    "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
 )
+parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
 parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
 parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
 parser.add_argument(
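The `--all-at-once` and `--ddp-zero` flags are gone, while `--checkpoint`, `--lazy-construction`, and `--max-batch` survive in a new order. A minimal sketch exercising the surviving flags with a fresh parser; the real script registers many more options:

```python
import argparse

parser = argparse.ArgumentParser(description="subset of the pipe benchmark flags, for illustration")
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"],
    help="Checkpointing strategy for pipe",
)
parser.add_argument("--lazy-construction", action="store_true", default=False)
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")

args = parser.parse_args(["--checkpoint", "except_last", "--max-batch", "2", "--use_synthetic_data"])
print(args.checkpoint, args.max_batch, args.lazy_construction, args.use_synthetic_data)
# except_last 2 False True
```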