Unverified commit 42e44149, authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: remove pipelined_backward (#362)

parent 7fdd7ecf
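
For callers, the net effect is that pipelined_backward disappears from the MultiProcessPipe constructor and from the benchmark CLI: the value is now derived internally from the model-parallel world size, and AsyncPipe pins it to False. A minimal before/after sketch, assuming an import path and a distributed/RPC setup that are not part of this diff:

import torch.nn as nn
from fairscale.nn.pipe import MultiProcessPipe  # import path assumed

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
worker_map = {0: "worker0", 1: "worker1"}  # placeholder; real values come from the RPC worker setup

# Before this commit, call sites could pass the flag explicitly:
#   MultiProcessPipe(model, balance=[2, 1], worker_map=worker_map, chunks=4, pipelined_backward=False)

# After this commit, the keyword argument no longer exists; the pipe decides
# internally via get_model_parallel_world_size() > 1.
pipe = MultiProcessPipe(model, balance=[2, 1], worker_map=worker_map, chunks=4)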
@@ -423,7 +423,6 @@ def run_mp_worker(args, available_workers):
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
-        pipelined_backward=False,
        checkpoint=args.checkpoint,
    )
    if torch.cuda.is_available():
...
@@ -523,7 +523,6 @@ def run_mp_worker(args, available_workers):
        chunks=args.chunks,
        worker_map=get_worker_map(),
        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
-        pipelined_backward=args.pipelined_backward,
        checkpoint=args.checkpoint,
        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
    )
@@ -592,7 +591,6 @@ parser.add_argument(
parser.add_argument(
    "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
-parser.add_argument("--pipelined-backward", action="store_true", help="Pipelined backward pass")
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
...
@@ -39,6 +39,10 @@ class PartitionInfo:
class AsyncPipe(MultiProcessPipe):
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.pipelined_backward = False
+
    def create_pipeline(self) -> None:
        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
...
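
With the lines added above, AsyncPipe no longer threads a constructor flag through; it simply forces the attribute off after the base initializer runs. A small illustrative sketch of the resulting behavior (worker_map is a placeholder, and a working process-group/RPC setup as in the tests is assumed):

import torch.nn as nn

# Placeholder worker_map; real values come from the RPC worker setup.
worker_map = {0: "worker0"}
async_pipe = AsyncPipe(nn.Sequential(nn.Linear(1, 1)), balance=[1], worker_map=worker_map, chunks=2)
assert async_pipe.pipelined_backward is False  # forced off regardless of world size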
@@ -118,13 +118,6 @@ class MultiProcessPipe(Module):
            whether to use deferred BatchNorm moving statistics (default:
            :data:`False`, see :class:`DeferredBatchNorm` for more
            details)
-        pipelined_backward (bool, optional):
-            if True, call torch.autograd.backward once per microbatch on the
-            backward pass (instead of once for the whole batch). This works
-            around a potential deadlock in pytorch when using tensor parallelism
-            at the same time. Defaults to `True` if
-            `get_model_parallel_world_size() > 1`
-            (default: `None`)

    Raises:
        TypeError:
@@ -174,7 +167,6 @@ class MultiProcessPipe(Module):
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
-        pipelined_backward: bool = None,
    ) -> None:
        super().__init__()
@@ -183,13 +175,17 @@ class MultiProcessPipe(Module):
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

+        if get_model_parallel_world_size() > 1:
+            self.pipelined_backward = True
+        else:
+            self.pipelined_backward = False
+
        self.balance = list(balance)
        verify_module(module)
        check_balance(module, self.balance)

        self.chunks = chunks
        self.checkpoint = checkpoint
-        self.pipelined_backward = pipelined_backward

        self.pipeline: Optional[MultiProcessPipeline]
        self.lock = threading.Lock()
@@ -227,12 +223,6 @@ class MultiProcessPipe(Module):
        del module

-        if self.pipelined_backward is None:
-            if get_model_parallel_world_size() > 1:
-                self.pipelined_backward = True
-            else:
-                self.pipelined_backward = False
-
    def create_pipeline(self) -> None:
        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
...
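
Inside MultiProcessPipe itself, the decision moves from a lazy is-None check near the end of __init__ to an explicit assignment right after argument validation, and it is no longer user-overridable. A minimal sketch of the heuristic the constructor now applies (hypothetical helper name; import path assumed):

from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size  # import path assumed

def default_pipelined_backward() -> bool:
    # Mirrors the logic now inlined in __init__: pipeline the backward pass
    # per micro-batch only when model (tensor) parallelism is active, which
    # works around the autograd deadlock the removed docstring described.
    return get_model_parallel_world_size() > 1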
@@ -443,7 +443,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
        worker_map=worker_map,
        input_device=torch.cuda.current_device(),
        chunks=chunk_size,
-        pipelined_backward=True,
    ).cuda()
    torch.distributed.barrier()
    pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
...
@@ -259,15 +259,9 @@ def checkpoint_mode(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    input = torch.rand(2, 1)

-    always = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always", pipelined_backward=False,
-    )
-    except_last = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last", pipelined_backward=False,
-    )
-    never = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never", pipelined_backward=False,
-    )
+    always = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always",)
+    except_last = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last",)
+    never = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never",)

    always_output = always(input)
    except_last_output = except_last(input)
@@ -306,7 +300,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def checkpoint_eval(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
@@ -343,9 +337,7 @@ def checkpoint_non_float_input(pipe_class):
        return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = pipe_class(
-        model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always",)
    input = torch.rand(1, requires_grad=True)
    output = model(input)
@@ -456,7 +448,7 @@ def input_pair(pipe_class):
            return (self.fc_a(a), self.fc_b(b))

    model = nn.Sequential(Two())
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)
@@ -482,7 +474,7 @@ def input_singleton(pipe_class):
            return (self.fc(a),)

    model = nn.Sequential(One())
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
    a = torch.rand(10, 1, requires_grad=True)
@@ -766,7 +758,7 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):

@torch_spawn([4])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe])
def pipelined_backward(pipe_class):
    model = nn.Sequential(nn.ReLU(), nn.ReLU())
...