Unverified commit ab5fc48f authored by Aidyn-A, committed by GitHub

Add grad check in test pipeline parallel fwd bwd (#1386)

* add grad check

* change assert

* minor changes

* revert unnecessary changes

* suggested changes

* fix tensor comparison

* small changes
parent da1f7f2f
...
@@ -48,15 +48,24 @@ def get_init_weights_func(offset: int = 0):
     return init_weights


-def get_target_loss(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> float:
-    with torch.no_grad():
-        data = torch.ones(global_batch_shape, dtype=torch.float)
-        for i in range(total_layers):
-            w = torch.ones(hidden_size, hidden_size) * (i + 1.0) / weight_coeff
-            # don't need to care about transpose semantics as all values are the same
-            data = torch.matmul(w, data)
-            data += 1.0
-    return float((data.sum()/global_batch_shape[0]).item())
+def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> float:
+    model = []
+    data = torch.ones(global_batch_shape, dtype=torch.float)
+    for i in range(total_layers):
+        w = torch.ones((hidden_size, hidden_size)) * (i + 1.0) / weight_coeff
+        b = torch.ones(hidden_size)
+
+        w.requires_grad_()
+        b.requires_grad_()
+
+        # don't need to care about transpose semantics as all values are the same
+        data = torch.matmul(w, data) + b
+        model.append([w, b])
+
+    loss = data.sum() / global_batch_shape[0]
+    loss.backward()
+
+    return loss, model


 class PipelineParallelForwardBackwardTestBase:
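For context, the new helper builds the same constant-weight layer stack the test exercises and runs a reference backward pass on it, so the test can later compare per-layer gradients. Below is a hedged standalone sketch of that computation; weight_coeff and the shapes are assumed illustration values, not the test suite's actual constants:

import torch

# Assumed illustration values; in the test file weight_coeff is a
# module-level constant and the shapes come from the test parameters.
weight_coeff = 1024.0
hidden_size = 16
global_batch_shape = (4, hidden_size, hidden_size)
total_layers = 2

model = []
data = torch.ones(global_batch_shape, dtype=torch.float)
for i in range(total_layers):
    # Layer i: a constant weight matrix and an all-ones bias, with grads
    # enabled so the reference backward pass leaves gradients on w and b.
    w = torch.ones((hidden_size, hidden_size)) * (i + 1.0) / weight_coeff
    b = torch.ones(hidden_size)
    w.requires_grad_()
    b.requires_grad_()
    data = torch.matmul(w, data) + b
    model.append([w, b])

loss = data.sum() / global_batch_shape[0]
loss.backward()  # populates w.grad / b.grad for every layer
print(loss.item(), model[0][0].grad.sum().item())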
...
@@ -170,11 +179,24 @@ class PipelineParallelForwardBackwardTestBase:
         total_layers = pipeline_model_parallel_world_size
         if virtual_pipeline_model_parallel_size is not None:
             total_layers *= virtual_pipeline_model_parallel_size
-        target_loss = get_target_loss(global_batch_shape, hidden_size, total_layers)
+        target_loss, target_model = get_target_loss_and_model(global_batch_shape, hidden_size, total_layers)

         for loss_item in loss:
             x = loss_item['avg']
-            torch.testing.assert_close(x / microbatch_size, target_loss*torch.ones_like(x))
+            torch.testing.assert_close(x.item() / microbatch_size, target_loss.item())
+
+        if not forward_only:
+            for vm_id, model_module in enumerate(model):
+                params = list(model_module.parameters())
+                rank = params[0].get_device()
+                offset = pipeline_model_parallel_world_size
+                param_id = rank // data_parallel_size + vm_id * offset
+                target_params = target_model[param_id]
+
+                torch.testing.assert_close(params[0].cpu(), target_params[0])
+                torch.testing.assert_close(params[1].cpu(), target_params[1])
+                torch.testing.assert_close(params[0].grad.cpu() / microbatch_size, target_params[0].grad)
+                torch.testing.assert_close(params[1].grad.cpu() / microbatch_size, target_params[1].grad)

         if not forward_only:
             for m in model:
...
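The grad check relies on the interleaved-pipeline layer layout: the pipeline stage (derived here from the parameter's device index via rank // data_parallel_size) owns, for each virtual chunk vm_id, reference layer stage + vm_id * pipeline_model_parallel_world_size. A small sketch of that index arithmetic under an assumed layout (2 pipeline stages, 2 virtual chunks, no data parallelism), for illustration only:

# Assumed layout: 2 pipeline stages, 2 virtual chunks, data_parallel_size = 1,
# so the reference model has 4 layers in total.
data_parallel_size = 1
pipeline_model_parallel_world_size = 2
virtual_pipeline_model_parallel_size = 2

offset = pipeline_model_parallel_world_size
for rank in range(pipeline_model_parallel_world_size):
    for vm_id in range(virtual_pipeline_model_parallel_size):
        # Same arithmetic as the test: the pipeline stage, advanced by one
        # full pipeline's worth of layers per virtual chunk.
        param_id = rank // data_parallel_size + vm_id * offset
        print(f"rank {rank}, chunk {vm_id} -> target_model layer {param_id}")
# rank 0 holds layers 0 and 2; rank 1 holds layers 1 and 3.

The divisions by microbatch_size in the assertions appear to normalize the loss and gradients the pipeline run accumulates across microbatches back to the reference run's scale.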