Unverified commit da1f7f2f, authored by eqy and committed by GitHub

Test `len(model) > 1` in `test_pipelining_with_interleaving` (#1384)



* check in

* type

* cleanup

* cleanup

* fix function call

* Apply suggestions from code review
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 891d57d3
@@ -37,27 +37,26 @@ logging.getLogger("apex").setLevel(logging.WARNING)
 weight_coeff = 1024

-@torch.no_grad()
-def init_weights(m):
-    rank = torch.distributed.get_rank()
-    if isinstance(m, torch.nn.Linear):
-        m.weight.fill_((rank + 1.0) / weight_coeff)
-        m.bias.fill_(1.0)
+def get_init_weights_func(offset: int = 0):
+    @torch.no_grad()
+    def init_weights(m):
+        rank = parallel_state.get_pipeline_model_parallel_rank()
+        if isinstance(m, torch.nn.Linear):
+            m.weight.fill_((rank + offset + 1.0) / weight_coeff)
+            m.bias.fill_(1.0)
+    return init_weights

-def get_target_loss(hidden_size: int, microbatch_size: int, parallel_model_world_size: int, world_size: int) -> float:
-    layers_per_rank = world_size // parallel_model_world_size
-    data = torch.arange(start = 0, end = layers_per_rank, dtype = torch.int) + 1
-    w = (torch.arange(world_size, dtype = torch.float) + 1.0) / weight_coeff
-    b = torch.ones(world_size, dtype = torch.int)
-    w = hidden_size * w
-    for s_id in range(0, world_size, layers_per_rank):
-        e_id = s_id+layers_per_rank
-        data = w[s_id:e_id] * data + b[s_id:e_id]
-    return hidden_size * hidden_size * torch.sum(data).item() * microbatch_size / layers_per_rank
+def get_target_loss(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> float:
+    with torch.no_grad():
+        data = torch.ones(global_batch_shape, dtype=torch.float)
+        for i in range(total_layers):
+            w = torch.ones(hidden_size, hidden_size) * (i + 1.0) / weight_coeff
+            # don't need to care about transpose semantics as all values are the same
+            data = torch.matmul(w, data)
+            data += 1.0
+        return float((data.sum()/global_batch_shape[0]).item())

 class PipelineParallelForwardBackwardTestBase:
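Because every element of `data`, of each `w`, and of the added bias is identical, the new reference collapses to a scalar recurrence, which makes the expected value easy to reason about. Below is a minimal sketch of that equivalence; `scalar_target_loss` is a hypothetical helper written only for illustration (not part of the test file), and it assumes the batch shape is one that `torch.matmul(w, data)` preserves, e.g. `(batch, hidden_size, hidden_size)`.

```python
import math

weight_coeff = 1024  # same constant as in the test file above

def scalar_target_loss(global_batch_shape, hidden_size, total_layers):
    # v is the common value of every element of the all-ones input. Each "layer"
    # applies matmul(w, data) + 1 with all entries of w equal to (i + 1) / weight_coeff,
    # which multiplies v by hidden_size * (i + 1) / weight_coeff and then adds 1.
    v = 1.0
    for i in range(total_layers):
        v = hidden_size * (i + 1.0) / weight_coeff * v + 1.0
    # mirrors float((data.sum() / global_batch_shape[0]).item())
    return v * math.prod(global_batch_shape) / global_batch_shape[0]
```

Up to floating-point rounding, this agrees with the tensor-based `get_target_loss` above for any shape of that form.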
@@ -82,7 +81,11 @@ class PipelineParallelForwardBackwardTestBase:
         default_backend: Optional[str] = None,
         p2p_backend: Optional[str] = None,
     ) -> None:
+        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
+            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
+            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
+
         dtype_options = self.dtypes or [torch.float32] + _get_autocast_dtypes()

         for dtype, deallocate_pipeline_outputs in itertools.product(
             dtype_options, self.deallocate_options,
         ):
@@ -121,7 +124,9 @@ class PipelineParallelForwardBackwardTestBase:
                 self.HIDDEN_SIZE,
             )
-            batch =(((self.rank + 1) * torch.ones(global_batch_shape)).cuda(), )
+            batch = None
+            if parallel_state.is_pipeline_first_stage():
+                batch = (torch.ones(global_batch_shape, dtype=torch.float).cuda(), )

             model = build_model(
                 testing_utils.model_provider_func,
@@ -131,8 +136,10 @@ class PipelineParallelForwardBackwardTestBase:
                 hidden_size=self.HIDDEN_SIZE,
             )
-            for model_module in model:
-                model_module.apply(init_weights)
+            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
+            for idx, model_module in enumerate(model):
+                model_module.apply(get_init_weights_func(idx*offset))

             _param_groups = _get_params_for_weight_decay_optimization(model)
             optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
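With interleaving, `build_model` returns one module chunk per virtual stage, and the `idx*offset` shift gives every chunk's `Linear` layer a distinct initial weight, so the layers of the virtual pipeline line up with the `(i + 1) / weight_coeff` factors iterated by the reference loss. A small illustrative enumeration follows; the sizes are chosen only for this example, while the real test derives them from the distributed launch.

```python
pipeline_world_size = 2   # example value, stands in for pipeline_model_parallel_world_size
virtual_size = 2          # example value, stands in for virtual_pipeline_model_parallel_size
weight_coeff = 1024       # constant from the test file
offset = pipeline_world_size

for idx in range(virtual_size):               # chunk index within each rank's model list
    for rank in range(pipeline_world_size):   # pipeline-parallel rank holding that chunk
        w = (rank + idx * offset + 1.0) / weight_coeff
        print(f"chunk {idx} on rank {rank}: weight {w:.6f}")
# Prints four distinct values, 1/1024 .. 4/1024, i.e. one per layer of the virtual pipeline.
```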
@@ -160,11 +167,14 @@ class PipelineParallelForwardBackwardTestBase:
             if dtype == torch.float32:
                 hidden_size = self.HIDDEN_SIZE
                 microbatch_size = self.MICRO_BATCH_SIZE
-                target_loss = get_target_loss(hidden_size, microbatch_size, pipeline_model_parallel_world_size, self.world_size)
+                total_layers = pipeline_model_parallel_world_size
+                if virtual_pipeline_model_parallel_size is not None:
+                    total_layers *= virtual_pipeline_model_parallel_size
+                target_loss = get_target_loss(global_batch_shape, hidden_size, total_layers)
                 for loss_item in loss:
                     x = loss_item['avg']
-                    torch.testing.assert_close(x, target_loss*torch.ones_like(x))
+                    torch.testing.assert_close(x / microbatch_size, target_loss*torch.ones_like(x))

             if not forward_only:
                 for m in model:
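The reference now walks over every layer of the virtual pipeline rather than only the layers hosted by a single rank. A back-of-the-envelope with placeholder sizes (the real constants come from the test class and the launch configuration):

```python
# Placeholder values, for illustration only.
pipeline_model_parallel_world_size = 2
virtual_pipeline_model_parallel_size = 2

total_layers = pipeline_model_parallel_world_size
if virtual_pipeline_model_parallel_size is not None:
    total_layers *= virtual_pipeline_model_parallel_size
assert total_layers == 4  # one Linear layer per (pipeline rank, model chunk) pair
# The reported loss is then checked, per the hunk above, as
#   loss_item['avg'] / microbatch_size == get_target_loss(global_batch_shape, hidden_size, total_layers)
```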
@@ -203,12 +213,12 @@ class PipelineParallelForwardBackwardTestBase:
     def test_pipelining_with_interleaving(self):
         self._forward_backward_test_impl(
-            False, _forward_backward_pipelining_with_interleaving, None, None
+            False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )

     def test_pipelining_with_interleaving_inference(self):
         self._forward_backward_test_impl(
-            True, _forward_backward_pipelining_with_interleaving, None, None
+            True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )