Unverified Commit e614aa34 authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous...

[shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
parent df5e9c53
from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
from .shard_config import ShardConfig from .shard_config import ShardConfig
from .sharder import ModelSharder from .sharder import ModelSharder
from .shardformer import ShardFormer from .shardformer import ShardFormer
__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] __all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"]
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class GradientCheckpointConfig:
    """Basic gradient checkpoint (activation recomputation) configuration.

    Attributes:
        gradient_checkpointing_ratio (float): Fraction of layers to checkpoint,
            expected in [0, 1]. Defaults to 0.0 (checkpoint nothing).
    """

    gradient_checkpointing_ratio: float = 0.0

    def get_num_ckpt_layers(self, num_layers: int) -> int:
        """Return how many of ``num_layers`` layers should be checkpointed."""
        return int(self.gradient_checkpointing_ratio * num_layers)


@dataclass
class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
    r"""Gradient checkpoint config with fine-grained control for pipeline parallelism.

    Combined with ``PipelineStageManager.set_distribution_config``, users can fully
    control the distribution of layers and checkpointed layers in pipeline parallelism.
    Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details.

    It provides the following features:
        1. ``gradient_checkpointing_ratio``: controls gradient checkpointing more
           precisely, e.g. checkpoint 50% of the layers.
        2. Customized number of checkpointed layers assigned to each stage. This
           takes precedence over ``gradient_checkpointing_ratio``.

    Args:
        gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing.
            It can only be used in pipeline parallelism. Defaults to None.
        num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check.
        num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check.
        num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check.
        num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None.

    Example 1:
        num_stages = 8
        num_layers = 80
        num_model_chunks = 1
        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
        num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0]

    Example 2:
        num_stages = 4
        num_layers = 80
        num_model_chunks = 2
        num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11]
        # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers
        ...
    """

    num_stages: Optional[int] = None
    num_model_chunks: Optional[int] = None
    num_model_layers: Optional[int] = None
    num_ckpt_layers_per_stage: Optional[List[int]] = None

    def __post_init__(self):
        """Validate the configuration and derive the global checkpointing ratio.

        Raises:
            ValueError: If ``gradient_checkpointing_ratio`` is outside [0, 1].
            AssertionError: If the customized per-stage list is inconsistent with
                ``num_stages``/``num_model_chunks``/``num_model_layers``.
        """
        if self._enable_gradient_checkpointing_ratio:
            if not (0 <= self.gradient_checkpointing_ratio <= 1):
                raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")

        if self._enable_customized_ckpt_layers_per_stage:
            # The per-stage list only makes sense together with the pipeline shape.
            assert (
                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
            )
            # One entry per (stage, model chunk) pair.
            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
            assert all(
                0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage
            )
            # The explicit assignment determines the effective global ratio.
            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers

    @property
    def _enable_gradient_checkpointing_ratio(self) -> bool:
        # Ratio-based checkpointing is active whenever a ratio was provided.
        return self.gradient_checkpointing_ratio is not None

    @property
    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
        # The explicit per-stage list, when given, overrides the ratio.
        return self.num_ckpt_layers_per_stage is not None

    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
        """Return how many of this stage's ``num_layers`` layers should be checkpointed.

        Args:
            stage (int): 0-indexed pipeline stage.
            num_layers (int): Number of layers held by this stage/chunk.
            model_chunk_id (int): 0-indexed model chunk (for interleaved schedules). Defaults to 0.

        Raises:
            RuntimeError: If neither a ratio nor a per-stage list was configured.
        """
        if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
            raise RuntimeError("No checkpointed layers information is provided")

        if self._enable_customized_ckpt_layers_per_stage:
            # `stage` and `model_chunk_id` are 0-indexed, so they must be strictly
            # less than the configured counts (the original `<=` was an off-by-one
            # that let an out-of-range index slip past the sanity check).
            assert stage < self.num_stages and model_chunk_id < self.num_model_chunks
            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
            assert num_ckpt_layers <= num_layers
            return num_ckpt_layers
        else:
            return int(self.gradient_checkpointing_ratio * num_layers)
...@@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup ...@@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"] __all__ = ["ShardConfig"]
...@@ -23,6 +25,7 @@ class ShardConfig: ...@@ -23,6 +25,7 @@ class ShardConfig:
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
""" """
tensor_parallel_process_group: Optional[ProcessGroup] = None tensor_parallel_process_group: Optional[ProcessGroup] = None
...@@ -35,6 +38,7 @@ class ShardConfig: ...@@ -35,6 +38,7 @@ class ShardConfig:
enable_sequence_parallelism: bool = False enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False enable_sequence_overlap: bool = False
parallel_output: bool = True parallel_output: bool = True
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict) extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# TODO padding vocab # TODO padding vocab
# make_vocab_size_divisible_by: int = 128 # make_vocab_size_divisible_by: int = 128
......
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
...@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"] ...@@ -21,7 +20,6 @@ __all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
class OpenMoePolicy(Policy): class OpenMoePolicy(Policy):
def config_sanity_check(self): def config_sanity_check(self):
pass pass
...@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy): ...@@ -43,7 +41,8 @@ class OpenMoePolicy(Policy):
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError( raise NotImplementedError(
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
...@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy): ...@@ -97,8 +96,8 @@ class OpenMoePolicy(Policy):
else: else:
module = self.model.model module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls description=method_replacement, policy=policy, target_key=model_cls
...@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy): ...@@ -117,10 +116,10 @@ class OpenMoePolicy(Policy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = [] held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens) held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx]) held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
held_layers.append(module.norm) held_layers.append(module.norm)
...@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy): ...@@ -143,7 +142,6 @@ class OpenMoePolicy(Policy):
class OpenMoeModelPolicy(OpenMoePolicy): class OpenMoeModelPolicy(OpenMoePolicy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy): ...@@ -169,21 +167,21 @@ class OpenMoeModelPolicy(OpenMoePolicy):
class OpenMoeForCausalLMPolicy(OpenMoePolicy): class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm # add a new item for casual lm
new_item = { new_item = {
OpenMoeForCausalLM: OpenMoeForCausalLM: ModulePolicyDescription(
ModulePolicyDescription(sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(gather_output=True), kwargs=dict(gather_output=True),
) )
]) ]
)
} }
policy.update(new_item) policy.update(new_item)
...@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy): ...@@ -208,13 +206,17 @@ class OpenMoeForCausalLMPolicy(OpenMoePolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]: def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) if (
and self.pipeline_stage_manager.num_stages > 1): id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights # tie weights
return [{ return [
{
0: llama_model.embed_tokens.weight, 0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}] }
]
return [] return []
...@@ -247,12 +249,13 @@ class OpenMoePipelineForwards: ...@@ -247,12 +249,13 @@ class OpenMoePipelineForwards:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states output_hidden_states = (
if output_hidden_states is not None else self.config.output_hidden_states) output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds # retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
...@@ -320,7 +323,8 @@ class OpenMoePipelineForwards: ...@@ -320,7 +323,8 @@ class OpenMoePipelineForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False use_cache = False
# decoder layers # decoder layers
...@@ -333,12 +337,11 @@ class OpenMoePipelineForwards: ...@@ -333,12 +337,11 @@ class OpenMoePipelineForwards:
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
past_key_value = (past_key_values[idx] if past_key_values is not None else None) past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, output_attentions, None) return module(*inputs, output_attentions, None)
...@@ -384,14 +387,16 @@ class OpenMoePipelineForwards: ...@@ -384,14 +387,16 @@ class OpenMoePipelineForwards:
router_z_loss = past_router_z_loss + router_z_loss router_z_loss = past_router_z_loss + router_z_loss
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
return tuple([ return tuple(
[
hidden_states, hidden_states,
next_cache, next_cache,
all_hidden_states, all_hidden_states,
all_self_attns, all_self_attns,
router_aux_loss, router_aux_loss,
router_z_loss, router_z_loss,
]) ]
)
# always return dict for imediate stage # always return dict for imediate stage
return { return {
"hidden_states": hidden_states, "hidden_states": hidden_states,
...@@ -445,10 +450,11 @@ class OpenMoePipelineForwards: ...@@ -445,10 +450,11 @@ class OpenMoePipelineForwards:
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```""" ```"""
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states output_hidden_states = (
if output_hidden_states is not None else self.config.output_hidden_states) output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions: if output_attentions:
...@@ -504,7 +510,6 @@ class OpenMoePipelineForwards: ...@@ -504,7 +510,6 @@ class OpenMoePipelineForwards:
if chunk_head == True: if chunk_head == True:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
logits = module(inputs[0]) logits = module(inputs[0])
logits = logits.float() logits = logits.float()
...@@ -522,8 +527,8 @@ class OpenMoePipelineForwards: ...@@ -522,8 +527,8 @@ class OpenMoePipelineForwards:
for batch_idx in range(hidden_states.shape[0]): for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint( loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head), create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :], hidden_states[batch_idx : batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :], labels[batch_idx : batch_idx + 1, :],
) )
logits = None logits = None
else: else:
......
...@@ -49,9 +49,9 @@ if HAS_LLAMA: ...@@ -49,9 +49,9 @@ if HAS_LLAMA:
loss_fn_for_seq_classification = lambda output: output["logits"].mean() loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig( config = LlamaConfig(
num_hidden_layers=4, num_hidden_layers=8,
hidden_size=128, hidden_size=32,
intermediate_size=256, intermediate_size=64,
num_attention_heads=4, num_attention_heads=4,
max_position_embeddings=128, max_position_embeddings=128,
num_labels=16, num_labels=16,
......
import random
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.shardformer.shard.shard_config import ShardConfig
class _ShardConfig(ShardConfig):
    # Test double: disable ShardConfig's __post_init__ validation/setup so an
    # instance can be built for distribution tests with only a stage manager.
    def __post_init__(self):
        pass
class _PipelineStageManager(PipelineStageManager):
    # Minimal stub of PipelineStageManager for layer-distribution tests:
    # bypasses the real __init__ (no process groups / distributed setup needed).
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        # Randomized stage count so repeated calls exercise a range of
        # pipeline sizes (5..10 inclusive).
        return random.randint(5, 10)
def test_t5_pipeline_distribution(): def test_t5_pipeline_distribution():
...@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution(): ...@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
} }
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = T5BasePolicy() policy = T5BasePolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases): for i in range(num_test_cases):
_, decoder_starting_stage = policy.distribute_t5_layers( _, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i], test_dict["num_encoder_layers"][i],
...@@ -35,7 +57,10 @@ def test_t5_pipeline_layers(): ...@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
} }
for i in range(num_test_cases): for i in range(num_test_cases):
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = T5BasePolicy() policy = T5BasePolicy()
policy.set_shard_config(shard_config)
layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers( layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i], test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i], test_dict["num_decoder_layers"][i],
......
import random
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.whisper import WhisperPolicy from colossalai.shardformer.policies.whisper import WhisperPolicy
from colossalai.shardformer.shard.shard_config import ShardConfig
class _ShardConfig(ShardConfig):
    # Test double: disable ShardConfig's __post_init__ validation/setup so an
    # instance can be built for distribution tests with only a stage manager.
    def __post_init__(self):
        pass
class _PipelineStageManager(PipelineStageManager):
    # Minimal stub of PipelineStageManager for layer-distribution tests:
    # bypasses the real __init__ (no process groups / distributed setup needed).
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        # Randomized stage count so repeated calls exercise a range of
        # pipeline sizes (5..10 inclusive).
        return random.randint(5, 10)
def test_whisper_pipeline_distribution(): def test_whisper_pipeline_distribution():
...@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution(): ...@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
} }
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = WhisperPolicy() policy = WhisperPolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases): for i in range(num_test_cases):
_, decoder_starting_stage = policy.distribute_whisper_layers( _, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i], test_dict["num_encoder_layers"][i],
...@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers(): ...@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
], ],
} }
stage_manager = _PipelineStageManager()
shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
policy = WhisperPolicy() policy = WhisperPolicy()
policy.set_shard_config(shard_config)
for i in range(num_test_cases): for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i], test_dict["num_encoder_layers"][i],
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
...@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" ...@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config model_fn, loss_fn, test_config
) )
if enable_gradient_checkpointing:
org_model.gradient_checkpointing_enable()
sharded_model.unwrap().gradient_checkpointing_enable()
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
...@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
}, },
{ {
"tp_size": 1, "tp_size": 1,
...@@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 4, "num_microbatches": 4,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp32", "precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
),
}, },
{ {
"tp_size": 4, "tp_size": 4,
...@@ -189,6 +200,13 @@ def run_llama_test(test_config): ...@@ -189,6 +200,13 @@ def run_llama_test(test_config):
"precision": "fp16", "precision": "fp16",
"zero_stage": 1, "zero_stage": 1,
"initial_scale": 1, "initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2,
num_model_chunks=2,
num_model_layers=8,
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
}, },
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment