".github/workflows/compatiblity_test_on_pr.yml" did not exist on "6474e31556ae011410f29d8d0a2d80de67ed6956"
Unverified commit e614aa34, authored by Wenhao Chen, committed by GitHub


[shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
parent df5e9c53
from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
from .shard_config import ShardConfig
from .sharder import ModelSharder
from .shardformer import ShardFormer
__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"]
__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GradientCheckpointConfig:
    gradient_checkpointing_ratio: float = 0.0

    def get_num_ckpt_layers(self, num_layers: int) -> int:
        return int(self.gradient_checkpointing_ratio * num_layers)


@dataclass
class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
    r"""
    The pipeline gradient checkpoint config is designed to give users more flexibility to control gradient checkpointing in pipeline parallelism.
    Combined with PipelineStageManager.set_distribution_config, users can fully control the distribution of layers and checkpointed layers in pipeline parallelism.
    Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.

    It provides the following features:
        1. `gradient_checkpointing_ratio`: controls gradient checkpointing more precisely, e.g., checkpointing 50% of the layers.
        2. `num_ckpt_layers_per_stage`: customizes the number of checkpointed layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.

    Args:
        gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
        num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
        num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
        num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
        num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.

    Example 1:
        num_stages = 8
        num_layers = 80
        num_model_chunks = 1
        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
        num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]

    Example 2:
        num_stages = 4
        num_layers = 80
        num_model_chunks = 2
        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
        # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
        ...
    """

    num_stages: Optional[int] = None
    num_model_chunks: Optional[int] = None
    num_model_layers: Optional[int] = None
    num_ckpt_layers_per_stage: Optional[List[int]] = None

    def __post_init__(self):
        if self._enable_gradient_checkpointing_ratio:
            if not (0 <= self.gradient_checkpointing_ratio <= 1):
                raise ValueError("gradient_checkpointing_ratio should be in the range [0, 1]")

        if self._enable_customized_ckpt_layers_per_stage:
            assert (
                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
            )
            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
            assert all(
                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
            )
            # a fully specified per-stage list overrides the plain ratio
            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers

    @property
    def _enable_gradient_checkpointing_ratio(self) -> bool:
        return self.gradient_checkpointing_ratio is not None

    @property
    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
        return self.num_ckpt_layers_per_stage is not None

    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
        if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
            raise RuntimeError("No checkpointed layers information is provided")

        if self._enable_customized_ckpt_layers_per_stage:
            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
            # per-stage list is laid out chunk by chunk: [chunk 0: stage 0..S-1, chunk 1: stage 0..S-1, ...]
            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
            assert num_ckpt_layers <= num_layers
            return num_ckpt_layers
        else:
            return int(self.gradient_checkpointing_ratio * num_layers)
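
A minimal usage sketch of the new config (illustrative values, not taken from this commit): the ratio form checkpoints a fraction of each stage's layers, while `num_ckpt_layers_per_stage` pins an exact per-stage count and takes precedence. The top-level re-export used in the import is the one exercised by the tests later in this diff.

# Illustrative sketch only; the expected values follow from the formulas in grad_ckpt_config.py above.
from colossalai.shardformer import PipelineGradientCheckpointConfig

# Ratio form: roughly half of each stage's layers are recomputed.
ratio_cfg = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)
assert ratio_cfg.get_num_ckpt_layers(stage=0, num_layers=10) == 5

# Per-stage form: exact counts per (chunk, stage); overrides the ratio.
custom_cfg = PipelineGradientCheckpointConfig(
    num_stages=2,
    num_model_chunks=1,
    num_model_layers=8,
    num_ckpt_layers_per_stage=[4, 0],
)
assert custom_cfg.get_num_ckpt_layers(stage=0, num_layers=4) == 4
assert custom_cfg.get_num_ckpt_layers(stage=1, num_layers=4) == 0
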
......@@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"]
......@@ -23,6 +25,7 @@ class ShardConfig:
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
......@@ -35,6 +38,7 @@ class ShardConfig:
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
parallel_output: bool = True
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# TODO padding vocab
# make_vocab_size_divisible_by: int = 128
......
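
For context, a hedged sketch of how the new `gradient_checkpoint_config` field is attached to a ShardConfig. Field names follow the hunk above; a real setup would also supply a pipeline_stage_manager so the pipeline-specific config can take effect, and the policy then queries the config per stage during sharding.

# Sketch: wire the new field into ShardConfig; other arguments are left at their defaults here.
from colossalai.shardformer import PipelineGradientCheckpointConfig, ShardConfig

shard_config = ShardConfig(
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
num_ckpt = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(stage=0, num_layers=8)  # -> 4
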
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union
......@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
class OpenMoePolicy(Policy):
def config_sanity_check(self):
pass
......@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
......@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
else:
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
......@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
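
The layer-distribution helpers have moved from the policy onto PipelineStageManager, as the replaced lines above show. A small sketch of the new call pattern follows; the wrapper function name is made up for illustration, and calling it needs a real stage manager from an initialized pipeline context.

# Sketch of the relocated helpers; `held_decoder_layers` is a hypothetical wrapper, not part of this PR.
from typing import List

from torch import nn

from colossalai.pipeline.stage_manager import PipelineStageManager


def held_decoder_layers(module: nn.Module, stage_manager: PipelineStageManager) -> List[nn.Module]:
    # before this commit: self.distribute_layers(len(module.layers), stage_manager.num_stages)
    layers_per_stage = stage_manager.distribute_layers(len(module.layers))
    # before this commit: self.get_stage_index(layers_per_stage, stage_manager.stage)
    start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
    return list(module.layers[start_idx:end_idx])
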
......@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):
class OpenMoeModelPolicy(OpenMoePolicy):
def __init__(self) -> None:
super().__init__()
......@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):
class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
OpenMoeForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
OpenMoeForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
])
]
)
}
policy.update(new_item)
......@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1):
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [{
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}]
}
]
return []
......@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
......@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
......@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (past_key_values[idx] if past_key_values is not None else None)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
......@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
router_z_loss = past_router_z_loss + router_z_loss
if stage_manager.is_last_stage():
return tuple([
return tuple(
[
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
])
]
)
# always return dict for intermediate stage
return {
"hidden_states": hidden_states,
......@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
......@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
if chunk_head == True:
def create_custom_forward(module):
def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
......@@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
hidden_states[batch_idx : batch_idx + 1, :],
labels[batch_idx : batch_idx + 1, :],
)
logits = None
else:
......
......@@ -49,9 +49,9 @@ if HAS_LLAMA:
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=8,
hidden_size=32,
intermediate_size=64,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16,
......
import random
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.shardformer.shard.shard_config import ShardConfig
class _ShardConfig(ShardConfig):
def __post_init__(self):
pass
class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
@property
def num_stages(self):
return random.randint(5, 10)
def test_t5_pipeline_distribution():
......@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
}
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = T5BasePolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases):
_, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i],
......@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
}
for i in range(num_test_cases):
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = T5BasePolicy()
policy.set_shard_config(shard_config)
layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i],
......
import random
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.whisper import WhisperPolicy
from colossalai.shardformer.shard.shard_config import ShardConfig
class _ShardConfig(ShardConfig):
def __post_init__(self):
pass
class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
@property
def num_stages(self):
return random.randint(5, 10)
def test_whisper_pipeline_distribution():
......@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
}
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = WhisperPolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases):
_, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i],
......@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
],
}
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = WhisperPolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i],
......
......@@ -5,6 +5,7 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
......@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
if enable_gradient_checkpointing:
org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable()
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
......@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
},
{
"tp_size": 1,
......@@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 4,
"use_lazy_init": False,
"precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
),
},
{
"tp_size": 4,
......@@ -189,6 +200,13 @@ def run_llama_test(test_config):
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2,
num_model_chunks=2,
num_model_layers=8,
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
},
],
)
......
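
The test configs above feed `enable_gradient_checkpointing` and `gradient_checkpoint_config` through the hybrid-plugin test helper. Below is a hedged sketch of the equivalent user-facing setup; the exact plugin kwargs are inferred from these tests rather than spelled out in this diff, and constructing the plugin requires an initialized distributed environment.

# Sketch, inferred from the test configs above (not verified against the plugin code in this diff).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    precision="fp32",
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(
        num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
    ),
)
booster = Booster(plugin=plugin)
# The model still opts in the usual way, mirroring `enable_gradient_checkpointing` in the tests:
# model.gradient_checkpointing_enable()
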