Unverified Commit 46e09165 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix
parent 2a0558d8
...@@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): ...@@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
Returns: Returns:
None None
""" """
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism: if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None: if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group. # Synchronize provided gradient tensors across the tensor parallelism group.
...@@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): ...@@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns: Returns:
None None
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs) super().backward(loss, *args, **kwargs)
...@@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): ...@@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns: Returns:
None None
""" """
# Call the superclass backward method to compute gradients. # Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad) super().backward_by_grad(tensor, grad)
...@@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): ...@@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
Returns: Returns:
None None
""" """
# Call the superclass `_sync_grad` method to synchronize gradients. # Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad() super()._sync_grad()
...@@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return True return True
def support_no_sync(self) -> bool: def support_no_sync(self) -> bool:
return False return True
def control_checkpoint_io(self) -> bool: def control_checkpoint_io(self) -> bool:
return True return True
...@@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase):
model, data_iter, criterion, optimizer, return_loss, return_outputs model, data_iter, criterion, optimizer, return_loss, return_outputs
) )
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
# Synchronize the grads of shared parameters of the model. # Synchronize the grads of shared parameters of the model.
model.sync_shared_params() model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model. # Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads() model.sync_sp_grads()
...@@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]: def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError assert (
self.zero_stage != 2
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
import copy
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.testing import assert_close
from torch.utils.data import Dataset
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
...@@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta ...@@ -11,9 +14,33 @@ from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device, set_seed
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
set_seed(42)
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}
@clear_cache_before_run() @clear_cache_before_run()
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try: try:
...@@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): ...@@ -85,10 +112,145 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
@parameterize(
"test_args",
[
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 0,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 8,
"num_steps": 4,
"tp": 1,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 1,
"num_steps": 4,
"tp": 2,
"pp": 1,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 1,
"zero": 2,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 1,
"num_steps": 4,
"tp": 2,
"pp": 1,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 1,
"zero": 0,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
],
)
def run_grad_acc_test(test_args):
model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()))
model = model_fn()
optimizer = HybridAdam(model.parameters())
origin_model = copy.deepcopy(model).cuda()
origin_optimizer = HybridAdam(origin_model.parameters())
plugin = HybridParallelPlugin(
tp_size=test_args["tp"],
pp_size=test_args["pp"],
pp_style=test_args["pp_style"],
zero_stage=test_args["zero"],
num_model_chunks=test_args["num_model_chunks"],
enable_fused_normalization=True,
num_microbatches=test_args["num_microbatches"],
precision=test_args["precision"],
)
booster = Booster(plugin=plugin)
dataset = RandomDataset(
num_samples=test_args["batch_size"] * test_args["num_steps"] * plugin.dp_size,
max_length=test_args["max_length"],
vocab_size=model.config.vocab_size,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=test_args["batch_size"], shuffle=True, drop_last=True)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
grad_accu_step = test_args["gradient_accumulation_step"]
for step, batch in enumerate(dataloader):
batch = move_to_cuda(batch)
# train origin model
origin_output = origin_model(**batch)
origin_loss = origin_output[0] / grad_accu_step
origin_loss.backward()
if (step + 1) % grad_accu_step != 0 and test_args["zero"] != 2:
ctx = booster.no_sync(model, optimizer)
else:
ctx = nullcontext()
with ctx:
if plugin.stage_manager is not None:
batch = iter([batch])
booster.execute_pipeline(
batch,
model,
criterion=lambda outputs, inputs: outputs[0] / grad_accu_step,
optimizer=optimizer,
return_loss=False,
)
else:
outputs = model(**batch)
loss = outputs[0] / grad_accu_step
booster.backward(loss, optimizer)
if (step + 1) % grad_accu_step == 0:
# update origin model weight
origin_optimizer.step()
origin_optimizer.zero_grad()
# update sharded model
optimizer.step()
optimizer.zero_grad()
# tricky code here, shard the origin model inorder to check the parameters in the same stage.
origin_model, origin_optimizer, _, dataloader, _ = booster.boost(
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
def run_dist(rank, world_size, port, early_stop: bool = True): def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env # init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
check_3d_plugin(early_stop=early_stop) check_3d_plugin(early_stop=early_stop)
run_grad_acc_test()
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment