Unverified commit cae9b638, authored by msbaines and committed by GitHub
Browse files

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
This diff is collapsed.
...@@ -21,13 +21,13 @@ import pytest ...@@ -21,13 +21,13 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn import Pipe from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def simple_linears(pipeline_style): def simple_linears(pipeline_style):
def sum_grad(parameters): def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None]) return sum([p.grad.sum() for p in parameters if p.grad is not None])
...@@ -40,7 +40,7 @@ def simple_linears(pipeline_style): ...@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
inputs = torch.rand(8, 1) inputs = torch.rand(8, 1)
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),) model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
# Without Pipe # Without MultiProcessPipe
outputs = model(inputs) outputs = model(inputs)
loss = outputs.mean() loss = outputs.mean()
loss.backward() loss.backward()
...@@ -54,20 +54,20 @@ def simple_linears(pipeline_style): ...@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
zero_grad(model.parameters()) zero_grad(model.parameters())
# With Pipe # With MultiProcessPipe
model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4) model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
outputs = model(inputs) outputs = model(inputs)
if model.group.rank() == 1: if model.group.rank() == 1:
loss = outputs.mean() loss = outputs.mean()
loss.backward() loss.backward()
grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters()) grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
# Both grads should be identical. # Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[1]) assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
else: else:
model.back_helper(outputs) model.back_helper(outputs)
grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters()) grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
# Both grads should be identical. # Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[0]) assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment